Llama: Add grammar-based sampling

  • Here's my understanding of how this works (please someone correct me if I'm getting this wrong).

    Language models emit tokens one at a time, starting with the prompt that you give them.

    If you have a conversation with an LLM, effectively you can think of that as you giving it a sequence of tokens, then it generating some, then you giving it more, and so on.

    This grammar trick takes advantage of that by giving you much finer-grained control over the tokens. So you can do things like this:

        Give me the address of the
        White House as JSON:
        
        {"street": "
    
    Then the LLM can return:

        1600 Pennsylvania Ave NW"
    
    The moment you see that closing double quote, you take over again and inject:

        ",
        "City": "
    
    It fills in:

        Washington, DC"
    
    And so on.
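
    Put together, the loop looks roughly like this (a sketch in Python, with a hypothetical complete(prompt, stop) helper standing in for however you sample from the model until a stop string appears):

        # Hypothetical helper: sample tokens from the model after `prompt`
        # until it emits `stop`, returning the text before the stop string.
        def complete(prompt: str, stop: str) -> str: ...

        prompt = 'Give me the address of the White House as JSON:\n\n{"street": "'
        street = complete(prompt, stop='"')    # the model fills in the value
        prompt += street + '",\n"city": "'     # we inject the fixed structure
        city = complete(prompt, stop='"')
        prompt += city + '"}'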

    But because this is all based on a grammar, you can do way more with it than just JSON.

    I saw a brilliant suggestion relating to this on Twitter a while ago:

    > @OpenAI should add an API argument allowing passing up a deterministic context free grammar.

    > [...]

    > While I think DCFL is what you want here in the short term, the really best thing is passing up a small WASM binary that simply is the sampler.

    > Allow a user to pass up a few KB of WASM binary and give it a few megabytes of RAM to run. Would enable next level LLM superpowers.

    https://twitter.com/grantslatton/status/1637692033115762688

  • I think it should be noted that this enforces grammatical constraints on the model's generated text, but it doesn't do anything to align the content itself. It would be useful if you needed to ensure a server delivered well-formed JSON, but I suspect it won't solve a lot of alignment issues with current language generation. For example, current iterations of Llama and GPT often do not label markdown code blocks correctly. Using grammar-based sampling you could enforce that the model labels code blocks, but you couldn't enforce correct labeling, since that is context-dependent. You also couldn't invent a novel domain-specific language and expect good output without aligning the model against that language.

  • I am in love with this. I tried my hand at building a Constrained Text Generation Studio (https://github.com/Hellisotherpeople/Constrained-Text-Genera...) and got published at COLING 2022 for my paper on it (https://paperswithcode.com/paper/most-language-models-can-be...), but I always knew that something like this, or the related idea enumerated in https://arxiv.org/abs/2306.03081, was the way to go.

    I will have to think about how I can build grammars that force things like syllable counts or syntactic rules. Current LLMs do very poorly on those kinds of tasks due to the tokenization schemes...

  • I implemented this for PyTorch too at https://github.com/Shopify/torch-grammar. I have a hacked version of text-generation-inference that uses it—happy to share that if it’s useful to anyone.

  • Specifically for multi-choice string enums (essentially dropdowns), I wonder if this would work better if the full (joint/product) probability given the logits were considered when picking the final choice, rather than using a greedy algorithm. That would favor the right choice, as opposed to e.g. one of the choices that contains the most common start token, if a start token is shared among many items in the list.

    Of course, the probabilities need to be renormalized once a subset of the logits is zeroed out, so that the comparison actually makes sense...
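
    Roughly what I mean, as a sketch, with tokenize and token_logprobs standing in for whatever the model API actually exposes:

        # Stand-ins for the model API: tokenize a string, and get the
        # log-probability of each token when generated after `prefix`.
        def tokenize(text: str) -> list[int]: ...
        def token_logprobs(prefix: str, tokens: list[int]) -> list[float]: ...

        def pick_choice(prefix: str, choices: list[str]) -> str:
            # Score each allowed string by the sum of its token log-probs
            # (i.e. the product of the per-token probabilities), instead of
            # greedily committing to whichever first token looks best.
            scores = {c: sum(token_logprobs(prefix, tokenize(c))) for c in choices}
            return max(scores, key=scores.get)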

  • This grammar "library" was cited as an example of what the format could look like:

    https://github.com/antlr/grammars-v4

    There is everything from assembly and C++ to GLSL and scripting languages, arithmetic, games, and other weird formats like freedesktop shortcuts, LLVM IR, or Verilog.

  • Can someone ELI5 what's going on here? I'm reasonably familiar with LLMs, but I can't quite grok what Georgi is doing here and why it's so exciting for some.

  • I'm interested in this and I'm going to try incorporating it into something I'm doing. That said, I feel like this could be one of those Bitter Lesson situations where it's not the most effective approach in anything but the very short term: http://www.incompleteideas.net/IncIdeas/BitterLesson.html

  • Can anyone recommend a paper or overview on how "sampling" / "decoding" is done in the end-to-end neural network age? I know how decoding was done for machine translation and speech recognition back in the HMM days (e.g. https://en.wikipedia.org/wiki/Viterbi_algorithm and https://en.wikipedia.org/wiki/Beam_search). These days I get the impression people mostly just do "greedy", but I don't really know. Any recommendations for info on that topic?
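
    What I mean by "greedy" versus sampling, as a rough per-token sketch:

        import numpy as np

        def decode_step(logits: np.ndarray, temperature: float = 1.0) -> int:
            # "Greedy": just take the highest-scoring token every step.
            if temperature == 0.0:
                return int(np.argmax(logits))
            # Otherwise sample from the softmax, with temperature scaling
            # (top-k / top-p would truncate `logits` before this point).
            probs = np.exp((logits - logits.max()) / temperature)
            probs /= probs.sum()
            return int(np.random.choice(len(probs), p=probs))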

    Edit: Forgot Viterbi

  • How is this different from Guidance and LMQL?

  • This is great and all.

    But LLMs are usually very good at following grammars; I rarely see an LLM generating code that is out of distribution. Of course, this is only true for popular languages (JSON, Python, Java, etc.), so I can see how this is handy for more niche and in-house DSLs.

    You still need quite a lot of prompt engineering to get the desired outputs; this just adds another layer of output verification, IMO. But does it really save much compared to getting the output, then parsing and rejecting anything that doesn't follow the grammar? Might be debatable.
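
    The parse-and-reject alternative I have in mind, as a rough sketch (generate is a placeholder for an unconstrained model call):

        import json

        # Placeholder for an unconstrained model call.
        def generate(prompt: str) -> str: ...

        def generate_json(prompt: str, max_retries: int = 5) -> dict:
            # Post-hoc approach: generate freely, then parse and reject
            # anything that doesn't follow the grammar (here: JSON).
            for _ in range(max_retries):
                output = generate(prompt)
                try:
                    return json.loads(output)
                except json.JSONDecodeError:
                    continue  # throw the output away and try again
            raise ValueError("no valid JSON after max_retries attempts")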

    But great work regardless.

  • So, umm, if you want to walk a BNF grammar and emit likely tokens, you can do that without any "machine learning" or whatever you want to call it. So what is being added here? Training to tie the prompt to the output?

  • Interesting that the second commenter is Tobias Lütke, CEO of Shopify.

  • Ah, finally. This has been discussed a lot and is well overdue. It remains to be seen how well the models will adapt to this new constraint, though the demo seems promising.

  • Has anyone tested FreeWilly2 (the new Llama 2 fine-tune released today by Stability AI) on code generation?

  • Does anyone know Japanese well enough to comment on the output from the Japanese example?

  • Something I’ve been wondering lately: if you are generating tokens fast enough, is restricting the logits actually worth it computationally? If tokens are cheap enough, it might be more efficient to validate/discard them as they come rather than place constraints on how they come out. I don’t know how this one works, but the sampling or renormalizing scheme would cost something too, right?
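
    The restriction I'm picturing is roughly this extra per-step work, as a sketch; the set of allowed token ids would come from the grammar state:

        import numpy as np

        def constrained_sample(logits: np.ndarray, allowed: list[int]) -> int:
            # Mask every token the grammar disallows, then renormalize and sample.
            # The per-step overhead is this mask plus a softmax over the vocab,
            # on top of whatever it costs to compute `allowed` from the grammar.
            masked = np.full_like(logits, -np.inf)
            masked[allowed] = logits[allowed]
            probs = np.exp(masked - logits[allowed].max())
            probs /= probs.sum()
            return int(np.random.choice(len(probs), p=probs))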

  • Could someone help me with context? I'm OOTL and don't understand what is going on here.