A switch on Lua strings

In one of the projects I'm currently working on, Lua is used a lot as a data description language. As a result, there are many occasions where I'd like a switch statement to work over a Lua string. One example of this is the following function:

Art* MakeArt(lua_State* L)
{
  lua_getfield(L, -1, "type");
  switch(lua_tostring(L, -1))
  {
  case "Graphic":   return new GraphicArt(L);
  case "Rectangle": return new RectangleArt(L);
  case "Text":      return new TextArt(L);
  case "Line":      return new LineArt(L);
  case "Swf":       return new SwfArt(L);
  }
  return new UnknownArt(L);
}

To my mind, this code perfectly expresses the idea of performing different actions based on different strings from Lua. The only problem is that this code won't work. A common translation into code which does actually work is the following:

Art* MakeArt(lua_State* L)
{
  lua_getfield(L, -1, "type");
  if(const char* s = lua_tostring(L, -1))
  {
    if(!strcmp(s, "Graphic"))   return new GraphicArt(L);
    if(!strcmp(s, "Rectangle")) return new RectangleArt(L);
    if(!strcmp(s, "Text"))      return new TextArt(L);
    if(!strcmp(s, "Line"))      return new LineArt(L);
    if(!strcmp(s, "Swf"))       return new SwfArt(L);
  }
  return new UnknownArt(L);
}

This translation isn't perfect, as a string like "Line\0Stuff" will result in a LineArt rather than an UnknownArt, but usually this is acceptable. A further refinement might also make use of the string length:

Art* MakeArt(lua_State* L)
{
  lua_getfield(L, -1, "type");
  size_t len;
  if(const char* s = lua_tolstring(L, -1, &len))
  {
    switch(len)
    {
    case 7:
      if(!memcmp(s, "Graphic", 7)) return new GraphicArt(L);
      break;
    case 9:
      if(!memcmp(s, "Rectangle", 9)) return new RectangleArt(L);
      break;
    case 4:
      if(!memcmp(s, "Text", 4)) return new TextArt(L);
      if(!memcmp(s, "Line", 4)) return new LineArt(L);
      break;
    case 3:
      if(!memcmp(s, "Swf", 3)) return new SwfArt(L);
      break;
    }
  }
  return new UnknownArt(L);
}
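To make the embedded-zero caveat concrete, here is a small C++ sketch. A char array with an explicitly tracked length stands in for a Lua string, which unlike a C string may contain zero bytes. The names are illustrative only.

```cpp
#include <cstddef>
#include <cstring>

// Stand-in for a Lua string containing an embedded zero byte.
// sizeof counts the literal's terminating zero, so subtract one.
const char kEmbedded[] = "Line\0Stuff";
const std::size_t kEmbeddedLen = sizeof(kEmbedded) - 1;  // 10 bytes, not 4

// The strcmp check: strcmp stops at the first zero byte, so it sees only "Line".
bool isLineByStrcmp(const char* s) {
  return std::strcmp(s, "Line") == 0;
}

// The length-aware check, as in the switch(len) version: the length must
// match before the bytes are compared.
bool isLineByLength(const char* s, std::size_t len) {
  return len == 4 && std::memcmp(s, "Line", 4) == 0;
}
```

The strcmp version accepts the ten-byte string as "Line". The length-aware version rejects it.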

This refinement is technically correct, and probably faster due to dispatching based on length and the use of memcmp rather than strcmp. However, compared to the very first (non-functional) code fragment, this is harder to read and harder to maintain. This point is important to bear in mind, but things will get worse still before they get better. It just so happens that in the project I'm working on, the Lua library is statically linked. This means I can take advantage of Lua's implementation details safe in the knowledge that the implementation isn't going to change. The detail which I'd like to take advantage of is that Lua calculates a hash value for each string, and stores this in a header prior to the string contents. It also stores the length in this header, which we can also take advantage of. This train of thought leads to the following code fragment:

Art* MakeArt(lua_State* L)
{
  lua_getfield(L, -1, "type");
  if(auto ts = reinterpret_cast<const TString*>(lua_tostring(L, -1)))
  {
    switch(ts[-1].tsv.len)
    {
    case 7:
      if(ts[-1].tsv.hash == 1408413413ULL) return new GraphicArt(L);
      break;
    case 9:
      if(ts[-1].tsv.hash == 3769334599ULL) return new RectangleArt(L);
      break;
    case 4:
      switch(ts[-1].tsv.hash)
      {
      case 7903477ULL: return new TextArt(L);
      case 7864041ULL: return new LineArt(L);
      }
      break;
    case 3:
      if(ts[-1].tsv.hash == 205275ULL) return new SwfArt(L);
      break;
    }
  }
  return new UnknownArt(L);
}

As with the previous version based on strcmp, this version isn't perfect. It is highly likely (at least for strings longer than 4 characters) that there will be some other strings of the same length and hash value, but this is acceptable to me. On the positive side, this version should be extremely quick, as switching over a hash should be faster than sequential memcmps. Unfortunately, not only is this version unreadable and unmaintainable, it is also unwritable.
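To see why it is unwritable, note that every constant has to come from Lua's own string hash. The following sketch reproduces the hash computed in Lua 5.1's luaS_newlstr, which matches the TString layout used above. The function name is made up, and the sketch assumes a 32-bit unsigned int. For "Swf" and "Text" it yields 205275 and 7903477, the constants used above.

```cpp
#include <cstddef>
#include <string>

// Hypothetical helper reproducing the Lua 5.1 string hash (lstring.c).
// Assumes unsigned int is 32 bits, as on common platforms.
unsigned int lua51StringHash(const std::string& str) {
  std::size_t l = str.size();
  unsigned int h = static_cast<unsigned int>(l);  // seed with the length
  std::size_t step = (l >> 5) + 1;  // long strings: hash only every step-th char
  for (std::size_t l1 = l; l1 >= step; l1 -= step)  // walk from the end
    h = h ^ ((h << 5) + (h >> 2) + static_cast<unsigned char>(str[l1 - 1]));
  return h;
}
```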

At this point, enter code generation. It is an entirely mechanical process to translate the original (non-functional) code into hash-based dispatch code, and furthermore the hash calculations are best done by a machine. Therefore I've written a preprocessing script for C++ source files which performs this translation. The input looks almost like the original (non-functional) code, with the addition of some markers:

Art* MakeArt(lua_State* L)
{
  lua_getfield(L, -1, "type");
#ifdef LUA_STRING_SWITCH
  switch(lua_tostring(L, -1))
  {
  case "Graphic":   return new GraphicArt(L);
  case "Rectangle": return new RectangleArt(L);
  case "Text":      return new TextArt(L);
  case "Line":      return new LineArt(L);
  case "Swf":       return new SwfArt(L);
  }
#endif
  return new UnknownArt(L);
}

The script goes through the file, finds each #ifdef LUA_STRING_SWITCH, throws away any existing #else block, and then writes an #else block based on the contents of the switch, resulting in something like:

Art* MakeArt(lua_State* L)
{
  lua_getfield(L, -1, "type");
#ifdef LUA_STRING_SWITCH
  switch(lua_tostring(L, -1))
  {
  case "Graphic":   return new GraphicArt(L);
  case "Rectangle": return new RectangleArt(L);
  case "Text":      return new TextArt(L);
  case "Line":      return new LineArt(L);
  case "Swf":       return new SwfArt(L);
  }
#else
  if(auto ts = reinterpret_cast<const TString*>(lua_tostring(L, -1)))
  {
    switch(ts[-1].tsv.len)
    {
    case 7:
      if(ts[-1].tsv.hash == 1408413413ULL) return new GraphicArt(L);
      break;
    case 9:
      if(ts[-1].tsv.hash == 3769334599ULL) return new RectangleArt(L);
      break;
    case 4:
      switch(ts[-1].tsv.hash)
      {
      case 7903477ULL: return new TextArt(L);
      case 7864041ULL: return new LineArt(L);
      }
      break;
    case 3:
      if(ts[-1].tsv.hash == 205275ULL) return new SwfArt(L);
      break;
    }
  }
#endif
  return new UnknownArt(L);
}

At compile time, LUA_STRING_SWITCH isn't defined, and so the hash-based dispatch gets used. At edit time, I use the folding features of an IDE to hide the nasty hash-based dispatch, so all I see is:

Art* MakeArt(lua_State* L)
{
  lua_getfield(L, -1, "type");
#ifdef LUA_STRING_SWITCH
  switch(lua_tostring(L, -1))
  {
  case "Graphic":   return new GraphicArt(L);
  case "Rectangle": return new RectangleArt(L);
  case "Text":      return new TextArt(L);
  case "Line":      return new LineArt(L);
  case "Swf":       return new SwfArt(L);
  }
#else /* Hidden Preprocessor Block */
#endif
  return new UnknownArt(L);
}

This gives the best of both worlds. I can write pseudo-C++ which expresses my ideas very clearly, and then it can be mechanically translated and compiled down to something very efficient.

Mapping and Lua Iterators

Introduction to simple iterators in Lua

Let us begin by considering what a simple iterator is. In Lua, a simple iterator is a function. Each time you call the function, you get back the next value in the sequence. In the following example, f is (assumed to be) a simple iterator:

local f = ("cat"):gmatch(".")
print(f()) --> c
print(f()) --> a
print(f()) --> t

[Technical note: In the reference implementation of Lua, string.gmatch returns a simple iterator, but it is questionable whether or not this is mandated by the language. In particular, LuaJIT returns a full iterator rather than a simple iterator; see later]

As Lua functions can return multiple values, it seems logical to allow iterators to produce multiple values per call, and this is indeed allowed. In the following example, f is (assumed to be) a simple iterator which produces two values per call:

local f = ("The Cute Cat"):gmatch("(%u)(%l*)")
print(f()) --> T, he
print(f()) --> C, ute
print(f()) --> C, at

To signal that it has finished producing values, an iterator returns nil. Well, nearly; an iterator can signal completion by returning no values, by returning a single nil, or by returning multiple values with the first value being nil. With this in mind, the previous example can be rewritten to use a loop rather than an explicit number of calls to f:

local f = ("The Cute Cat"):gmatch("(%u)(%l*)")
while true do
  local first, rest = f()
  if first == nil then
    break
  end
  print(first, rest)
end

This is rather verbose syntax, and thankfully the generic for loop can be used to achieve the same effect:

local f = ("The Cute Cat"):gmatch("(%u)(%l*)")
for first, rest in f do
  print(first, rest)
end

In the previous examples, string.gmatch is termed a simple iterator factory, as it returns a simple iterator (f). Equipped with this knowledge of how simple iterators work, we can write a simple iterator factory which emulates the numeric for loop:

function range(init, limit, step)
  step = step or 1
  return function()
    local value = init
    init = init + step
    if limit * step >= value * step then
      return value
    end
  end
end

This can then be used like so:

for n in range(3, 0, -1) do print(n) end --> 3; 2; 1; 0
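For comparison with a language where captured state is more explicit, here is a rough C++ analogue of this simple iterator factory. The lambda's captures (init, limit, step) play the role of the three upvalues, and an empty std::optional stands in for nil. It assumes C++17, and range and collect are illustrative names.

```cpp
#include <functional>
#include <optional>
#include <vector>

// Simple iterator factory: the returned lambda carries its own state.
// Its captures correspond to the three Lua upvalues, and returning
// std::nullopt plays the part of returning nil.
std::function<std::optional<int>()> range(int init, int limit, int step = 1) {
  return [init, limit, step]() mutable -> std::optional<int> {
    int value = init;
    init += step;  // advance the captured state for the next call
    if (limit * step >= value * step)
      return value;
    return std::nullopt;
  };
}

// Like a generic for loop over a simple iterator: call f until it signals
// completion, collecting each value.
std::vector<int> collect(std::function<std::optional<int>()> f) {
  std::vector<int> out;
  while (std::optional<int> v = f())
    out.push_back(*v);
  return out;
}
```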

Introduction to full iterators in Lua

Looking at range in slightly more detail, we see that every time it is called, it returns a function, and that the function has three upvalues (init, limit, and step). From an esoteric efficiency point of view, it would be nice to have fewer upvalues (or ideally no upvalues at all). Full iterators are one way of achieving this. Whereas a simple iterator is just a function (f), a full iterator is a three-tuple (f,s,v). As with a simple iterator, f is still a function (or something with a __call metamethod). As for the two new bits, s is an opaque value which is passed as the first parameter to each call of f, and v is an opaque value which is passed as the second parameter to the first call of f. Subsequent calls of f pass the first return value from the previous call of f as the second parameter.

Written out formally like that, full iterators sound a little bit odd, but once you've wrapped your head around it, they are really quite reasonable. As an example, let us rewrite range to be a full iterator factory rather than a simple iterator factory. For this, we'll take s to be limit, and v to be init - step, like so:

function range(init, limit, step)
  step = step or 1
  return function(lim, value)
    value = value + step
    if lim * step >= value * step then
      return value
    end
  end, limit, init - step
end

As the generic for loop is designed to work with full iterators, this version of range can be used just like the previous version:

for n in range(3, 0, -1) do print(n) end --> 3; 2; 1; 0

Alternatively, to appreciate what is going on, we could use a verbose loop instead:

local f, s, v = range(3, 0, -1)
while true do
  v = f(s, v)
  if v == nil then
    break
  end
  print(v)
end
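The same full iterator protocol can be sketched in C++ (assuming C++17, with illustrative names). Here the iterator function is a plain function pointer with no captures at all. That works because step travels inside the state value s alongside limit, which differs slightly from the Lua version above, where step remained an upvalue.

```cpp
#include <optional>
#include <tuple>
#include <vector>

// Invariant state s: here both the limit and the step.
struct RangeState {
  int limit;
  int step;
};

// Iterator function f: captures nothing; receives s and the control value v
// (the previous result) and returns the next value, or nullopt for "nil".
std::optional<int> rangeNext(RangeState s, int value) {
  value += s.step;
  if (s.limit * s.step >= value * s.step)
    return value;
  return std::nullopt;
}

using RangeFn = std::optional<int> (*)(RangeState, int);

// Full iterator factory: returns the triple (f, s, v), with v = init - step.
std::tuple<RangeFn, RangeState, int> range(int init, int limit, int step = 1) {
  return std::make_tuple(&rangeNext, RangeState{limit, step}, init - step);
}

// What the generic for loop does with (f, s, v): v = f(s, v) until "nil".
std::vector<int> collectRange(int init, int limit, int step = 1) {
  auto [f, s, v] = range(init, limit, step);
  std::vector<int> out;
  while (std::optional<int> next = f(s, v)) {
    v = *next;  // the first result becomes the next control value
    out.push_back(v);
  }
  return out;
}
```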

Looking at this version of range, the returned function only has one upvalue (step). In real terms, this means that 32 fewer bytes of space are allocated on the heap as compared to the three upvalue version. Admittedly this isn't much, and it hardly justifies the language support for full iterators. Rather more convincing are the pairs and ipairs full iterator factories from the Lua standard library. The full iterator machinery is precisely what these two functions need in order for their iterator functions to have no upvalues at all, meaning that you can iterate through a table without having to allocate anything on the heap, as in the following example:

local t = {
  Lemon = "sour",
  Cake = "nice",
}
for ingredient, taste in pairs(t) do
  print(ingredient .." is ".. taste)
end
--> Lemon is sour
--> Cake is nice

Mapping Lua iterators

Switching from Lua to Haskell for just a minute, recall the type signature of the map function:

map :: (a -> b) -> [a] -> [b]

This takes a function as its first parameter, and applies this function to every element of a list in order to generate a new list. Due to Haskell's lazy evaluation behaviour, it isn't unreasonable to equate Haskell lists with iterators in other languages. With that in mind, a first attempt at a map function for Lua iterators might be:

function map(mapf, f, ...)
  return function(...)
    return mapf(f(...))
  end, ...
end

For all intents and purposes, this is just dressed up function composition. For some cases, like the following, it even works:

local f = map(string.upper, ("cat"):gmatch("."))
print(f()) --> C
print(f()) --> A
print(f()) --> T

Unfortunately, the following example doesn't work quite as well:

local t = {
  Lemon = "sour",
  Cake = "nice",
}
for ingredient, taste in map(function(a, b)
  return a:lower(), b:upper()
end, pairs(t)) do
  print(ingredient .." is ".. taste)
end
--> lemon is SOUR
--> invalid key to 'next'

If you've understood how full iterators work, then you should be able to appreciate why this example fails after the first iteration: the map function is changing the results from the iterator function, and the first of these changed results is being passed back to the iterator function, which confuses it. To solve this issue, we need a slightly more complicated construction:

function map(mapf, f, s, v)
  local function domap(...)
    v = ...
    if v ~= nil then
      return mapf(...)
    end
  end
  return function()
    return domap(f(s, v))
  end
end

With this construction, we get the intended behaviour:

local t = {
  Lemon = "sour",
  Cake = "nice",
}
for ingredient, taste in map(function(a, b)
  return a:lower(), b:upper()
end, pairs(t)) do
  print(ingredient .." is ".. taste)
end
--> lemon is SOUR
--> cake is NICE
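The essence of the fix is that the wrapper keeps the original control value for itself and only hands mapped values onward. This can be sketched in C++ too, assuming C++17. Here nextEntry is a hypothetical stand-in for next. Like next, it finds its place by looking up the previous key, so it would fail in the same way if given a mapped key.

```cpp
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

using Entry = std::pair<std::string, std::string>;
using Table = std::vector<Entry>;

// Hypothetical stand-in for Lua's next(t, key): it locates its position by
// looking up the previous key, so it cannot continue from an unknown key.
std::optional<Entry> nextEntry(const Table& t, const std::optional<std::string>& key) {
  std::size_t i = 0;
  if (key) {
    while (i < t.size() && t[i].first != *key)
      ++i;
    if (i == t.size())
      return std::nullopt;  // unknown key (Lua raises "invalid key to 'next'")
    ++i;                    // move past the previous key
  }
  if (i < t.size())
    return t[i];
  return std::nullopt;
}

// Corrected map: remembers the *unmapped* key as its own control value and
// only passes mapped entries to the caller. Like the Lua version, it returns
// a simple iterator (a closure). The table must outlive the iterator.
std::function<std::optional<Entry>()> mapPairs(std::function<Entry(const Entry&)> mapf,
                                               const Table& t) {
  std::optional<std::string> v;  // control value, initially "nil"
  return [mapf, &t, v]() mutable -> std::optional<Entry> {
    std::optional<Entry> e = nextEntry(t, v);
    if (!e)
      return std::nullopt;
    v = e->first;  // feed the original key back next time
    return mapf(*e);
  };
}
```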

ConcatMapping Lua iterators

Jumping back to Haskell again, consider the slightly more complex concatMap function:

concatMap :: (a -> [b]) -> [a] -> [b]

This allows the mapping function to produce zero or more results for each input value, for example:

module Main where
import Data.Char
main = putStr $
  concatMap (\c -> if isUpper c then [c] else [c,c]) "Cat" -- > "Caatt"

We could perform a direct conversion of this to Lua, and have the mapping function return an iterator, but unlike Haskell's syntactic support for constructing lists, Lua has none for constructing iterators, and so the result would feel rather clunky. The central issue is deciding how to return multiple values, or in fact multiple tuples, from the mapping function. In Haskell it is natural to do this with lists, whereas in Lua we can achieve it with coroutines. Consider the following even more complex form of the map function in Lua:

function map(mapf, f, s, v)
  local done
  local function maybeyield(...)
    if ... ~= nil then
      coroutine.yield(...)
    end
  end
  local function domap(...)
    v = ...
    if v ~= nil then
      return maybeyield(mapf(...))
    else
      done = true
    end
  end
  return coroutine.wrap(function()
    repeat
      domap(f(s, v))
    until done
  end)
end

This version of map still permits the previous example:

local t = {
  Lemon = "sour",
  Cake = "nice",
}
for ingredient, taste in map(function(a, b)
  return a:lower(), b:upper()
end, pairs(t)) do
  print(ingredient .." is ".. taste)
end
--> lemon is SOUR
--> cake is NICE

Furthermore, it supports the generation of multiple tuples for each input by calling coroutine.yield for each additional tuple:

local t = {
  Lemon = "sour",
  Cake = "nice",
}
for ingredient, taste in map(function(a, b)
  coroutine.yield(a:lower(), b:upper())
  return a, "very ".. b
end, pairs(t)) do
  print(ingredient .." is ".. taste)
end
--> lemon is SOUR
--> Lemon is very sour
--> cake is NICE
--> Cake is very nice

Alternatively, if the mapping function returns nothing (or returns nil as its first value) for some input, then nothing is generated for that input:

for n in map(function(n)
  if n ~= 0 then
    return 1 / n
  end
end, range(-2, 2)) do
  print(n)
end
--> -0.5; -1; 1; 0.5
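Outside of Lua, a similar "zero or more outputs per input" shape can be had without coroutines by handing the mapping function an emit callback to call as often as it likes. This is a different technique from the coroutine version above (push-based rather than pull-based). It is sketched here in C++ with illustrative names, reproducing the Haskell "Caatt" example.

```cpp
#include <cctype>
#include <functional>
#include <string>

using Emit = std::function<void(char)>;

// concatMap over characters, push style: mapf may call emit zero or more
// times for each input character, and everything emitted is concatenated.
std::string concatMap(const std::function<void(char, const Emit&)>& mapf,
                      const std::string& input) {
  std::string out;
  Emit emit = [&out](char c) { out.push_back(c); };
  for (char c : input)
    mapf(c, emit);
  return out;
}

// The Haskell example: upper-case characters pass through once, others twice.
std::string doubleNonUpper(const std::string& s) {
  return concatMap([](char c, const Emit& emit) {
    emit(c);
    if (!std::isupper(static_cast<unsigned char>(c)))
      emit(c);
  }, s);
}
```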

Algorithmic std::string creation

Suppose we want to create a C++ std::string object, and we know how long the resulting string will be, and we know how we'll generate the content of the string, but we do not have a copy of said content in memory. As a concrete example, consider taking a block of memory, and returning a string whose content is the hexadecimal representation of that memory block. The result should satisfy the following tests:

std::string hexify(const char*, size_t);

assertEqual("baadf00d", hexify("\xBA\xAD\xF0\x0D", 4));
assertEqual("baad"    , hexify("\xBA\xAD\xF0\x0D", 2));

A naïve approach might be something along the following lines:

#include <iomanip>
#include <sstream>
#include <string>

std::string hexifyChar(int c)
{
  std::stringstream ss;
  ss << std::hex << std::setw(2) << std::setfill('0') << c;
  return ss.str();
}

std::string hexify(const char* base, size_t len)
{
  std::stringstream ss;
  for(size_t i = 0; i < len; ++i)
    ss << hexifyChar(static_cast<unsigned char>(base[i])); // avoid sign-extending bytes >= 0x80
  return ss.str();
}

This is fairly readable code, and if speed weren't a concern, it could well be a viable solution. On the other hand, if speed is desired, then this code isn't particularly good, at least not without a clairvoyant optimising compiler. Both functions use the rather heavyweight std::stringstream tool, which may be slow because it is extremely flexible, and also because it doesn't know how long its result will be. Given that we do not need most of the flexibility of std::stringstream, and we know how long the result will be (2 for hexifyChar, and 2 * len for hexify), we should expect to be able to do better.
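As a baseline for what knowing the length up front buys, here is a sketch that reserves 2 * len characters once and appends two digits per byte. It is not the approach this post goes on to take, and no speed claim is made for it. The function name is illustrative.

```cpp
#include <cstddef>
#include <string>

// Builds the hex string with a single up-front allocation, no stringstream.
std::string hexifyReserve(const char* base, std::size_t len) {
  static const char digits[] = "0123456789abcdef";
  std::string result;
  result.reserve(2 * len);  // the final length is known in advance
  for (std::size_t i = 0; i < len; ++i) {
    // Go through unsigned char so bytes >= 0x80 don't sign-extend.
    unsigned char byte = static_cast<unsigned char>(base[i]);
    result.push_back(digits[byte >> 4]);   // high nibble first
    result.push_back(digits[byte & 0xF]);  // then low nibble
  }
  return result;
}
```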

As we want to construct a std::string, let us look at its constructors, and see which ones might work for us:

  1. Default constructor: string ();
  2. Copy constructor: string (const string& str);
  3. Substring constructor: string (const string& str, size_t pos, size_t n = npos);
  4. Existing content constructor: string (const char * s, size_t n);
  5. C string constructor: string (const char * s);
  6. Repetition constructor: string (size_t n, char c);
  7. Iterator constructor: template<class Iterator> string (Iterator begin, Iterator end);

None of these look particularly like what we want, but the last option might be viable if we can wrap a hexifying algorithm inside an iterator. If said iterator is in fact a random access iterator, then the constructor will be able to obtain the length of the result before it starts fetching the individual characters.

Writing an STL-compatible iterator usually takes a lot of code, but we can use boost::iterator_facade to do most of the work for us, giving us the following code:

#include <boost/iterator/iterator_facade.hpp>
#include <cstddef>
#include <string>

struct HexifyIterator
  : boost::iterator_facade<HexifyIterator, const char,
                           boost::random_access_traversal_tag>
{
  HexifyIterator()
    : ptr_()
    , nibble_() {}
  HexifyIterator(pointer ptr)
    : ptr_(reinterpret_cast<const unsigned char*>(ptr))
    , nibble_(1) {}
  HexifyIterator(const HexifyIterator& other)
    : iterator_facade(other)
    , ptr_(other.ptr_)
    , nibble_(other.nibble_) {}

private:
  friend class boost::iterator_core_access;

  const unsigned char* ptr_;
  difference_type nibble_;

  void increment() { advance(1); }
  void decrement() { advance(-1); }
  void advance(difference_type n)
  {
    nibble_ -= n;
    ptr_ -= nibble_ / 2;
    nibble_ %= 2;
    if(nibble_ < 0)
    {
      ++ptr_;
      nibble_ += 2;
    }
  }

  difference_type distance_to(const HexifyIterator& other) const
  { return (other.ptr_ - ptr_) * 2 + (nibble_ - other.nibble_); }

  bool equal(const HexifyIterator& other) const
  { return ptr_ == other.ptr_ && nibble_ == other.nibble_; }

  reference dereference() const
  { return "0123456789abcdef"[(*ptr_ >> (4 * nibble_)) & 0xF]; }
};

std::string hexify(const char* base, size_t len)
{
  return std::string(HexifyIterator(base), HexifyIterator(base + len));
}

There is still rather more code here than I would like, but in terms of speed, my quick and crude test put it at 72 times faster than the naïve approach.

A look at Lua 5.2 (beta rc2)

Just over a year ago, I looked at Lua 5.2 (work 3). Lua evolves slowly, and 5.2 still isn't quite finished, though perhaps it will be by the time the Lua Workshop 2011 rolls round in September (at which I hope to give a talk). In the meantime, Lua 5.2 (beta rc2) has been released, and this time I've identified the changes since 5.1 by annotating the new reference manual with them: An Annotated Lua 5.2 (beta rc2) Reference Manual.

Presenting the changes inline with the reference manual, as compared to just listing them, provides more context for the changes, as well as making things easier to consume by reducing the density of new information. If you have any good ideas for other ways of presenting this change information, then get in touch.

Callbacks with the LuaJIT FFI

The foreign function interface (FFI) present in the latest beta releases of LuaJIT is really nice for when you need to do things outside the Lua world. At runtime you pass it some C function definitions, and then everything you've defined becomes callable, and these calls are subsequently JIT compiled in the most efficient way possible. For example, we can define and then call some Windows API functions:

ffi = require "ffi"
ffi.cdef [[
  typedef void* HWND;
  HWND FindWindowA(const char* lpClassName, const char* lpWindowName);
  int GetWindowTextA(HWND hWnd, char* lpString, int nMaxCount); ]]
len = 300
buffer = ffi.new("char[?]", len)
window = ffi.C.FindWindowA("Notepad++", nil)
len = ffi.C.GetWindowTextA(window, buffer, len)
print(ffi.string(buffer, len)) --> C:\Lua\ffi_example.lua - Notepad++

This is fine and dandy for calling C from Lua, but things get rather more complicated with callbacks from C back to Lua. For example, the Windows EnumChildWindows function accepts a callback, which gets called for every child of the given window. LuaJIT will happily accept and understand the definition of this function:

ffi.cdef [[
  typedef void* HWND;
  typedef bool (*WNDENUMPROC)(HWND, long);
  bool EnumChildWindows(HWND hWndParent, WNDENUMPROC lpEnumFunc, long lParam); ]]

You quickly run into a problem if you try to call it, though, as you realise that the LuaJIT FFI currently lacks support for turning Lua functions into something which can be called from C. At this point, most people would acknowledge that the FFI isn't yet complete, and then go on to write their own C glue around EnumChildWindows using the traditional (slow) Lua C API. On the other hand, if you're feeling foolhardy, then you can fight the FFI to get callbacks working, and do so without resorting to any external C code. Naturally, this is what we'll do.

Our strategy will be to perform some control flow contortions so that when EnumChildWindows calls the callback, control in fact returns to Lua, and Lua then calls back in to resume the enumeration. If we could write it in Lua, then it might look something like:

EnumChildWindows = coroutine.wrap(function()
  while true do
    ffi.C.EnumChildWindows(coroutine.yield(), function(hWnd)
      coroutine.yield(hWnd)
    end, nil)
  end
end)

Naturally we cannot write this in Lua, but we can write it in machine code, and we can then use the FFI to load and execute machine code. The coroutine trickery will be done by the Windows fiber API, as fibers are fairly similar to coroutines.

To start with, ConvertThreadToFiber can be called to convert the currently running thread into a fiber and return a handle to the fiber. However, if the thread is already a fiber, then ConvertThreadToFiber fails (GetLastError reports ERROR_ALREADY_FIBER), and we run into a problem: the usual way of obtaining the current fiber, GetCurrentFiber, is a macro rather than a function, and hence is not callable via the FFI. For now we'll ignore this issue, but it will be addressed later. Next we can call VirtualAlloc to allocate some executable memory, use ffi.copy to copy some machine code into said executable memory, and then call the equivalent of coroutine.wrap, which is CreateFiber. In code, this looks like:

ffi.cdef [[
  void* ConvertThreadToFiber(void* lpParameter);
  typedef void (*LPFIBER_START_ROUTINE)(void*);
  void* CreateFiber(size_t dwStackSize, LPFIBER_START_ROUTINE lpStartAddress, void* lpParameter);
  void* VirtualAlloc(void* lpAddress, size_t dwSize, uint32_t flAllocationType, uint32_t flProtect); ]]
our_fiber = ffi.C.ConvertThreadToFiber(nil)
machine_code = "TODO"
procs = ffi.C.VirtualAlloc(nil, #machine_code + 1, 0x3000, 0x40)
ffi.copy(procs, machine_code)
contortion_fiber = ffi.C.CreateFiber(1024, ffi.cast("LPFIBER_START_ROUTINE", procs), nil)

The next task is to replace the TODO with the machine code equivalent of the following pseudo-C:

for(;;) {
  EnumChildWindows(coroutine.yield(), EnumerationProcedure, 0);
}

BOOL EnumerationProcedure(HWND hWnd, void* lpParam) {
  coroutine.yield(hWnd);
  return TRUE;
}

First of all we need to make the pseudo-C slightly more C-like. In particular, the above still uses a hypothetical coroutine.yield. The fiber API provides a SwitchToFiber function, which differs from coroutine.yield in that it doesn't support parameters or return values, and it has to be told which fiber to switch to. We thus end up with something like:

void* our_fiber; // The result of ConvertThreadToFiber.
void* transfer_slot[2]; // To yield a value, put the value in [0] and a non-NULL value in [1].
                        // To yield nothing, put anything in [0] and NULL in [1].
for(;;) {
  EnumChildWindows(transfer_slot[0], EnumerationProcedure, 0);
  transfer_slot[1] = NULL;
  SwitchToFiber(our_fiber);
}

BOOL EnumerationProcedure(HWND hWnd, void* lpParam) {
  transfer_slot[0] = hWnd;
  SwitchToFiber(our_fiber);
  return TRUE;
}

Next we need to convert this down to assembly code, firstly for x86:

fiber_proc:
push 0
push enum_proc
mov eax, dword ptr [transfer_slot]
push eax
call EnumChildWindows
mov dword ptr [transfer_slot + 4], 0
push our_fiber
call SwitchToFiber
jmp fiber_proc

enum_proc:
mov eax, dword ptr [esp+4]
mov dword ptr [transfer_slot], eax
push our_fiber
call SwitchToFiber
mov eax, 1
retn 8

And secondly for x64:

fiber_proc:
sub rsp, 28h
after_prologue:
mov rcx, qword ptr [rip->transfer_slot]
lea rdx, qword ptr [rip->enum_proc]
call qword ptr [rip->EnumChildWindows]
mov qword ptr [rip->transfer_slot + 8], 0
mov rcx, qword ptr [rip->our_fiber]
call qword ptr [rip->SwitchToFiber]
jmp after_prologue

enum_proc:
sub rsp, 28h
mov qword ptr [rip->transfer_slot], rcx
mov rcx, qword ptr [rip->our_fiber]
call qword ptr [rip->SwitchToFiber]
mov rax, 1
add rsp, 28h
ret

transfer_slot:    dq
                  dq
EnumChildWindows: dq
our_fiber:        dq
SwitchToFiber:    dq

At this point, we return to our earlier problem of GetCurrentFiber being a macro, and note that it boils down to the following assembly code, firstly for x86:

mov eax, dword ptr fs:[10h]
ret

And similarly for x64:

mov rax, qword ptr gs:[20h]
ret

Now we can convert the assembly down to machine code, and put everything together:

local ffi = require "ffi"
-- The definitions we want to use.
ffi.cdef [[
  typedef void* HWND;
  typedef bool (*WNDENUMPROC)(HWND, long);
  bool EnumChildWindows(HWND hWndParent, WNDENUMPROC lpEnumFunc, long lParam);
  int GetWindowTextA(HWND hWnd, char* lpString, int nMaxCount); ]]
-- Extra definitions we need for performing contortions with fibers.
ffi.cdef [[
  void* ConvertThreadToFiber(void* lpParameter);
  void SwitchToFiber(void* lpFiber);
  typedef void (*LPFIBER_START_ROUTINE)(void*);
  void* CreateFiber(size_t dwStackSize, LPFIBER_START_ROUTINE lpStartAddress, void* lpParameter);
  uint32_t GetLastError(void);
  void* VirtualAlloc(void* lpAddress, size_t dwSize, uint32_t flAllocationType, uint32_t flProtect);
  bool RtlAddFunctionTable(void* FunctionTable, uint32_t EntryCount, void* BaseAddress); ]]

local EnumChildWindows
do
  local GetLastError = ffi.C.GetLastError
  local contortion_fiber
  local procs
  local transfer_slot
  local init_callbacks
  if ffi.arch == "x86" then
    init_callbacks = function()
      -- Ensure that the thread is a fiber, converting if required.
      local our_fiber = ffi.C.ConvertThreadToFiber(nil)
      if our_fiber == nil and GetLastError() ~= 1280 then -- 1280 is ERROR_ALREADY_FIBER
        error("Unable to convert thread to fiber")
      end
      transfer_slot = ffi.new("void*[2]")
      -- fiber_proc: for(;;) {
      --               EnumChildWindows(transfer_slot[0], enum_proc, 0);
      --               transfer_slot[1] = 0; // to mark end of iteration
      --               SwitchToFiber(our_fiber);
      --             }
      local asm = "\x6A\x00" -- push 0
               .. "\x68????" -- push ????
               .. "\xA1????\x50" -- mov eax, dword ptr [????], push eax
               .. "\xE8????" -- call ????
               .. "\xC7\x05????\x00\x00\x00\x00" -- mov dword ptr [????], 0
               .. "\x68????" -- push ????
               .. "\xE8????" -- call ????
               .. "\xEB\xD8" -- jmp $-40
      -- enum_proc: transfer_slot[0] = *(esp+4); // the HWND
      --            SwitchToFiber(our_fiber);
      --            return TRUE;
               .. "\x8B\x44\x24\x04" -- mov eax, dword ptr [esp+4]
               .. "\x3E\xA3????" -- mov dword ptr [????], eax
               .. "\x68????" -- push ????
               .. "\xE8????" -- call ????
               .. "\x33\xC0\x40" -- xor eax, eax; inc eax (i.e. mov eax, 1)
               .. "\xC2\x08" -- retn 8 (the imm16's zero high byte is the NUL that ffi.copy appends)
      procs = ffi.C.VirtualAlloc(nil, #asm + 1, 0x3000, 0x40)
      if our_fiber == nil then
        -- GetCurrentFiber()
        ffi.copy(procs, "\x64\xA1\x10\x00\x00\x00\xC3") -- return __readfsdword(0x10)
        our_fiber = ffi.cast("void*(*)(void)", procs)()
      end
      ffi.copy(procs, asm)
      local function fixup(offset, ptr, isrelative)
        local dst = ffi.cast("char*", procs) + offset
        ptr = ffi.cast("char*", ptr)
        if isrelative then
          ptr = ffi.cast("char*", ptr - (dst + 4))
        end
        ffi.cast("char**", dst)[0] = ptr
      end
      fixup( 3, ffi.cast("char*", procs) + 40)
      fixup( 8, transfer_slot)
      fixup(14, ffi.C.EnumChildWindows, true)
      fixup(20, transfer_slot + 1)
      fixup(29, our_fiber)
      fixup(34, ffi.C.SwitchToFiber, true)
      fixup(46, transfer_slot)
      fixup(51, our_fiber)
      fixup(56, ffi.C.SwitchToFiber, true)
      contortion_fiber = ffi.C.CreateFiber(1024, ffi.cast("LPFIBER_START_ROUTINE", procs), nil)
      init_callbacks = function() end
    end
  elseif ffi.arch == "x64" then
    init_callbacks = function()
      -- Ensure that the thread is a fiber, converting if required.
      local our_fiber = ffi.C.ConvertThreadToFiber(nil)
      if our_fiber == nil and GetLastError() ~= 1280 then -- 1280 is ERROR_ALREADY_FIBER
        error("Unable to convert thread to fiber")
      end
      -- fiber_proc: for(;;) {
      --               EnumChildWindows(transfer_slot[0], enum_proc, 0);
      --               transfer_slot[1] = 0; // to mark end of iteration
      --               SwitchToFiber(our_fiber);
      --             }
      local asm = "\x48\x83\xEC\x28" -- sub rsp, 28h
               .. "\x48\x8B\x0D\x75\x00\x00\x00" -- mov rcx, [rip->transfer_slot_0]
               .. "\x48\x8D\x15\x26\x00\x00\x00" -- lea rdx, [rip->enum_proc]
               .. "\x48\xFF\x15\x77\x00\x00\x00" -- call [rip->EnumChildWindows]
               .. "\x48\xC7\x05\x64\x00\x00\x00\x00\x00\x00\x00" -- mov [rip->transfer_slot_1], 0
               .. "\x48\x8B\x0D\x6D\x00\x00\x00" -- mov rcx, [rip->our_fiber]
               .. "\x48\xFF\x15\x6E\x00\x00\x00" -- call [rip->SwitchToFiber]
               .. "\xEB\xD0" -- jmp $-48
               .. "\x90\x90\x90\x90" -- pad 8
      -- enum_proc: transfer_slot[0] = rcx; // the HWND
      --            SwitchToFiber(our_fiber);
      --            return TRUE;
               .. "\x48\x83\xEC\x28" -- sub rsp, 28h
               .. "\x48\x89\x0D\x3D\x00\x00\x00" -- mov [rip->transfer_slot_0], rcx
               .. "\x48\x8B\x0D\x4E\x00\x00\x00" -- mov rcx, [rip->our_fiber]
               .. "\x48\xFF\x15\x4F\x00\x00\x00" -- call [rip->SwitchToFiber]
               .. "\x48\xC7\xC0\x01\x00\x00\x00" -- mov rax, 1
               .. "\x48\x83\xC4\x28" -- add rsp, 28h
               .. "\xC3" -- ret
               .. "\x90\x90\x90" -- pad 8
      -- unwind data
               .. "\0\0\0\0\52\0\0\0\120\0\0\0"
               .. "\56\0\0\0\93\0\0\0\120\0\0\0"
               .. "\1\4\1\0\4\66"
               -- pad 8
      -- mutable data
               -- transfer_slot_0
               -- transfer_slot_1
               -- EnumChildWindows
               -- our_fiber
               -- SwitchToFiber
      procs = ffi.C.VirtualAlloc(nil, #asm + 42, 0x103000, 0x40)
      if our_fiber == nil then
        -- GetCurrentFiber()
        ffi.copy(procs, "\x65\x48\x8B\x04\x25\x20\x00\x00\x00\xC3") -- return __readgsqword(0x20)
        our_fiber = ffi.cast("void*(*)(void)", procs)()
      end
      ffi.copy(procs, asm)
      transfer_slot = ffi.cast("void**", ffi.cast("char*", procs) + 128)
      transfer_slot[2] = ffi.cast("void*", ffi.C.EnumChildWindows)
      transfer_slot[3] = ffi.cast("void*", our_fiber)
      transfer_slot[4] = ffi.cast("void*", ffi.C.SwitchToFiber)
      ffi.C.RtlAddFunctionTable(ffi.cast("void*", ffi.cast("char*", procs) + 96), 2, procs)
      contortion_fiber = ffi.C.CreateFiber(1024, ffi.cast("LPFIBER_START_ROUTINE", procs), nil)
      init_callbacks = function() end
    end
  else
    error("Only x86 and x64 are supported")
  end
  EnumChildWindows = function(wnd)
    init_callbacks()
    transfer_slot[0] = wnd
    transfer_slot[1] = ffi.cast("void*", 1)
    local results = {}
    while true do
      ffi.C.SwitchToFiber(contortion_fiber)
      if transfer_slot[1] == nil then
        return results
      else
        results[#results + 1] = transfer_slot[0]
      end
    end
  end
end
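The isrelative branch of fixup follows the usual x86 rule that a relative operand is measured from the end of the operand field. For the E8 call and EB jmp encodings used here, that is also the end of the instruction. Here is a tiny C++ sketch of that arithmetic, with a made-up helper name. The first check uses the x86 listing's values (the rel8 operand of the jmp at offset 39, target offset 0); the last uses arbitrary addresses.

```cpp
#include <cstdint>

// Hypothetical helper mirroring fixup's isrelative branch: a relative operand
// holds target - (address of the operand field + size of the field).
std::int64_t relativeDisplacement(std::uint64_t fieldAddress, unsigned fieldSize,
                                  std::uint64_t target) {
  return static_cast<std::int64_t>(target) -
         static_cast<std::int64_t>(fieldAddress + fieldSize);
}
```

Truncated to a byte, the jmp displacement of -40 becomes 0xD8, the operand byte in the x86 listing.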

With this mass of heavy complicated machinery in place, we can finally perform our original goal of enumerating windows:

local buffer  = ffi.new("char[?]", 300)
for _, window in ipairs(EnumChildWindows(nil)) do
  local len = ffi.C.GetWindowTextA(window, buffer, 300)
  if len ~= 0 then
    print(ffi.string(buffer, len))
  end
end

I freely admit that this solution isn't at all elegant, but it does show that callbacks are possible with the current LuaJIT FFI, without needing to resort to additional C libraries.
