From c03fa87956802d02589bfaf35367ce7cfff8abc9 Mon Sep 17 00:00:00 2001 From: Charles Duffy Date: Sun, 23 Jul 2023 13:15:40 -0500 Subject: [PATCH 01/38] pyproject.toml: extras list should contain only package list, not versions (#515) Update poetry.lock accordingly. --- poetry.lock | 219 ++++++++++++++++++++++++++++++++++++++----------- pyproject.toml | 7 +- 2 files changed, 175 insertions(+), 51 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9d12966..f68444a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,19 @@ # This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +[[package]] +name = "annotated-types" +version = "0.5.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = true +python-versions = ">=3.7" +files = [ + {file = "annotated_types-0.5.0-py3-none-any.whl", hash = "sha256:58da39888f92c276ad970249761ebea80ba544b77acddaa1a4d6cf78287d45fd"}, + {file = "annotated_types-0.5.0.tar.gz", hash = "sha256:47cdc3490d9ac1506ce92c7aaa76c579dc3509ff11e098fc867e5130ab7be802"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} + [[package]] name = "anyio" version = "3.6.2" @@ -373,22 +387,22 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.99.1" +version = "0.100.0" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = true python-versions = ">=3.7" files = [ - {file = "fastapi-0.99.1-py3-none-any.whl", hash = "sha256:976df7bab51ac7beda9f68c4513b8c4490b5c1135c72aafd0a5ee4023ec5282e"}, - {file = "fastapi-0.99.1.tar.gz", hash = "sha256:ac78f717cd80d657bd183f94d33b9bda84aa376a46a9dab513586b8eef1dc6fc"}, + {file = "fastapi-0.100.0-py3-none-any.whl", hash = "sha256:271662daf986da8fa98dc2b7c7f61c4abdfdccfb4786d79ed8b2878f172c6d5f"}, + {file = "fastapi-0.100.0.tar.gz", hash = "sha256:acb5f941ea8215663283c10018323ba7ea737c571b67fc7e88e9469c7eb1d12e"}, ] [package.dependencies] -pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0" +pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<3.0.0" starlette = ">=0.27.0,<0.28.0" typing-extensions = ">=4.5.0" [package.extras] -all = ["email-validator (>=1.1.1)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] [[package]] name = "ghp-import" @@ -987,55 +1001,150 @@ files = [ [[package]] name = "pydantic" -version = "1.10.7" -description = "Data validation and settings management using python type hints" +version = "2.0.3" +description = "Data validation using Python type hints" optional = true python-versions = ">=3.7" files = [ - {file = "pydantic-1.10.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e79e999e539872e903767c417c897e729e015872040e56b96e67968c3b918b2d"}, - {file = "pydantic-1.10.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:01aea3a42c13f2602b7ecbbea484a98169fb568ebd9e247593ea05f01b884b2e"}, - {file = 
"pydantic-1.10.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:516f1ed9bc2406a0467dd777afc636c7091d71f214d5e413d64fef45174cfc7a"}, - {file = "pydantic-1.10.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae150a63564929c675d7f2303008d88426a0add46efd76c3fc797cd71cb1b46f"}, - {file = "pydantic-1.10.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ecbbc51391248116c0a055899e6c3e7ffbb11fb5e2a4cd6f2d0b93272118a209"}, - {file = "pydantic-1.10.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f4a2b50e2b03d5776e7f21af73e2070e1b5c0d0df255a827e7c632962f8315af"}, - {file = "pydantic-1.10.7-cp310-cp310-win_amd64.whl", hash = "sha256:a7cd2251439988b413cb0a985c4ed82b6c6aac382dbaff53ae03c4b23a70e80a"}, - {file = "pydantic-1.10.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:68792151e174a4aa9e9fc1b4e653e65a354a2fa0fed169f7b3d09902ad2cb6f1"}, - {file = "pydantic-1.10.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dfe2507b8ef209da71b6fb5f4e597b50c5a34b78d7e857c4f8f3115effaef5fe"}, - {file = "pydantic-1.10.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10a86d8c8db68086f1e30a530f7d5f83eb0685e632e411dbbcf2d5c0150e8dcd"}, - {file = "pydantic-1.10.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d75ae19d2a3dbb146b6f324031c24f8a3f52ff5d6a9f22f0683694b3afcb16fb"}, - {file = "pydantic-1.10.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:464855a7ff7f2cc2cf537ecc421291b9132aa9c79aef44e917ad711b4a93163b"}, - {file = "pydantic-1.10.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:193924c563fae6ddcb71d3f06fa153866423ac1b793a47936656e806b64e24ca"}, - {file = "pydantic-1.10.7-cp311-cp311-win_amd64.whl", hash = "sha256:b4a849d10f211389502059c33332e91327bc154acc1845f375a99eca3afa802d"}, - {file = "pydantic-1.10.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:cc1dde4e50a5fc1336ee0581c1612215bc64ed6d28d2c7c6f25d2fe3e7c3e918"}, - {file = "pydantic-1.10.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0cfe895a504c060e5d36b287ee696e2fdad02d89e0d895f83037245218a87fe"}, - {file = "pydantic-1.10.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:670bb4683ad1e48b0ecb06f0cfe2178dcf74ff27921cdf1606e527d2617a81ee"}, - {file = "pydantic-1.10.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:950ce33857841f9a337ce07ddf46bc84e1c4946d2a3bba18f8280297157a3fd1"}, - {file = "pydantic-1.10.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c15582f9055fbc1bfe50266a19771bbbef33dd28c45e78afbe1996fd70966c2a"}, - {file = "pydantic-1.10.7-cp37-cp37m-win_amd64.whl", hash = "sha256:82dffb306dd20bd5268fd6379bc4bfe75242a9c2b79fec58e1041fbbdb1f7914"}, - {file = "pydantic-1.10.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8c7f51861d73e8b9ddcb9916ae7ac39fb52761d9ea0df41128e81e2ba42886cd"}, - {file = "pydantic-1.10.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6434b49c0b03a51021ade5c4daa7d70c98f7a79e95b551201fff682fc1661245"}, - {file = "pydantic-1.10.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64d34ab766fa056df49013bb6e79921a0265204c071984e75a09cbceacbbdd5d"}, - {file = "pydantic-1.10.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:701daea9ffe9d26f97b52f1d157e0d4121644f0fcf80b443248434958fd03dc3"}, - {file = 
"pydantic-1.10.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:cf135c46099ff3f919d2150a948ce94b9ce545598ef2c6c7bf55dca98a304b52"}, - {file = "pydantic-1.10.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0f85904f73161817b80781cc150f8b906d521fa11e3cdabae19a581c3606209"}, - {file = "pydantic-1.10.7-cp38-cp38-win_amd64.whl", hash = "sha256:9f6f0fd68d73257ad6685419478c5aece46432f4bdd8d32c7345f1986496171e"}, - {file = "pydantic-1.10.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c230c0d8a322276d6e7b88c3f7ce885f9ed16e0910354510e0bae84d54991143"}, - {file = "pydantic-1.10.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:976cae77ba6a49d80f461fd8bba183ff7ba79f44aa5cfa82f1346b5626542f8e"}, - {file = "pydantic-1.10.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d45fc99d64af9aaf7e308054a0067fdcd87ffe974f2442312372dfa66e1001d"}, - {file = "pydantic-1.10.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d2a5ebb48958754d386195fe9e9c5106f11275867051bf017a8059410e9abf1f"}, - {file = "pydantic-1.10.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:abfb7d4a7cd5cc4e1d1887c43503a7c5dd608eadf8bc615413fc498d3e4645cd"}, - {file = "pydantic-1.10.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:80b1fab4deb08a8292d15e43a6edccdffa5377a36a4597bb545b93e79c5ff0a5"}, - {file = "pydantic-1.10.7-cp39-cp39-win_amd64.whl", hash = "sha256:d71e69699498b020ea198468e2480a2f1e7433e32a3a99760058c6520e2bea7e"}, - {file = "pydantic-1.10.7-py3-none-any.whl", hash = "sha256:0cd181f1d0b1d00e2b705f1bf1ac7799a2d938cce3376b8007df62b29be3c2c6"}, - {file = "pydantic-1.10.7.tar.gz", hash = "sha256:cfc83c0678b6ba51b0532bea66860617c4cd4251ecf76e9846fa5a9f3454e97e"}, + {file = "pydantic-2.0.3-py3-none-any.whl", hash = "sha256:614eb3321eb600c81899a88fa9858b008e3c79e0d4f1b49ab1f516b4b0c27cfb"}, + {file = "pydantic-2.0.3.tar.gz", hash = "sha256:94f13e0dcf139a5125e88283fc999788d894e14ed90cf478bcc2ee50bd4fc630"}, ] [package.dependencies] -typing-extensions = ">=4.2.0" +annotated-types = ">=0.4.0" +pydantic-core = "2.3.0" +typing-extensions = ">=4.6.1" [package.extras] -dotenv = ["python-dotenv (>=0.10.4)"] -email = ["email-validator (>=1.0.3)"] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.3.0" +description = "" +optional = true +python-versions = ">=3.7" +files = [ + {file = "pydantic_core-2.3.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:4542c98b8364b976593703a2dda97377433b102f380b61bc3a2cbc2fbdae1d1f"}, + {file = "pydantic_core-2.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9342de50824b40f55d2600f66c6f9a91a3a24851eca39145a749a3dc804ee599"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:539432f911686cb80284c30b33eaf9f4fd9a11e1111fe0dc98fdbdce69b49821"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38a0e7ee65c8999394d92d9c724434cb629279d19844f2b69d9bbc46dc8b8b61"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_armv7l.whl", hash = "sha256:e3ed6834cc005798187a56c248a2240207cb8ffdda1c89e9afda4c3d526c2ea0"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_ppc64le.whl", hash = "sha256:e72ac299a6bf732a60852d052acf3999d234686755a02ba111e85e7ebf8155b1"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_s390x.whl", hash = "sha256:616b3451b05ca63b8f433c627f68046b39543faeaa4e50d8c6699a2a1e4b85a5"}, + {file = 
"pydantic_core-2.3.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:adcb9c8848e15c613e483e0b99767ae325af27fe0dbd866df01fe5849d06e6e1"}, + {file = "pydantic_core-2.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:464bf799b422be662e5e562e62beeffc9eaa907d381a9d63a2556615bbda286d"}, + {file = "pydantic_core-2.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4638ebc17de08c2f3acba557efeb6f195c88b7299d8c55c0bb4e20638bbd4d03"}, + {file = "pydantic_core-2.3.0-cp310-none-win32.whl", hash = "sha256:9ff322c7e1030543d35d83bb521b69114d3d150750528d7757544f639def9ad6"}, + {file = "pydantic_core-2.3.0-cp310-none-win_amd64.whl", hash = "sha256:4824eb018f0a4680b1e434697a9bf3f41c7799b80076d06530cbbd212e040ccc"}, + {file = "pydantic_core-2.3.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:0aa429578e23885b3984c49d687cd05ab06f0b908ea1711a8bf7e503b7f97160"}, + {file = "pydantic_core-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20d710c1f79af930b8891bcebd84096798e4387ab64023ef41521d58f21277d3"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:309f45d4d7481d6f09cb9e35c72caa0e50add4a30bb08c04c5fe5956a0158633"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bcfb7be905aa849bd882262e1df3f75b564e2f708b4b4c7ad2d3deaf5410562"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_armv7l.whl", hash = "sha256:85cd9c0af34e371390e3cb2f3a470b0b40cc07568c1e966c638c49062be6352d"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_ppc64le.whl", hash = "sha256:37c5028cebdf731298724070838fb3a71ef1fbd201d193d311ac2cbdbca25a23"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_s390x.whl", hash = "sha256:e4208f23f12d0ad206a07a489ef4cb15722c10b62774c4460ee4123250be938e"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c24465dd11b65c8510f251b095fc788c7c91481c81840112fe3f76c30793a455"}, + {file = "pydantic_core-2.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3cd7ee8bbfab277ab56e272221886fd33a1b5943fbf45ae9195aa6a48715a8a0"}, + {file = "pydantic_core-2.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0fc7e0b056b66cc536e97ef60f48b3b289f6b3b62ac225afd4b22a42434617bf"}, + {file = "pydantic_core-2.3.0-cp311-none-win32.whl", hash = "sha256:4788135db4bd83a5edc3522b11544b013be7d25b74b155e08dd3b20cd6663bbb"}, + {file = "pydantic_core-2.3.0-cp311-none-win_amd64.whl", hash = "sha256:f93c867e5e85584a28c6a6feb6f2086d717266eb5d1210d096dd717b7f4dec04"}, + {file = "pydantic_core-2.3.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:73f62bb7fd862d9bcd886e10612bade6fe042eda8b47e8c129892bcfb7b45e84"}, + {file = "pydantic_core-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4d889d498fce64bfcd8adf1a78579a7f626f825cbeb2956a24a29b35f9a1df32"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d55e38a89ec2ae17b2fa7ffeda6b70f63afab1888bd0d57aaa7b7879760acb4"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1aefebb506bc1fe355d91d25f12bcdea7f4d7c2d9f0f6716dd025543777c99a5"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_armv7l.whl", hash = "sha256:6441a29f42585f085db0c04cd0557d4cbbb46fa68a0972409b1cfe9f430280c1"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_ppc64le.whl", hash = 
"sha256:47e8f034be31390a8f525431eb5e803a78ce7e2e11b32abf5361a972e14e6b61"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_s390x.whl", hash = "sha256:ad814864aba263be9c83ada44a95f72d10caabbf91589321f95c29c902bdcff0"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9eff3837d447fccf2ac38c259b14ab9cbde700df355a45a1f3ff244d5e78f8b6"}, + {file = "pydantic_core-2.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:534f3f63c000f08050c6f7f4378bf2b52d7ba9214e9d35e3f60f7ad24a4d6425"}, + {file = "pydantic_core-2.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ef6a222d54f742c24f6b143aab088702db3a827b224e75b9dd28b38597c595fe"}, + {file = "pydantic_core-2.3.0-cp312-none-win32.whl", hash = "sha256:4e26944e64ecc1d7b19db954c0f7b471f3b141ec8e1a9f57cfe27671525cd248"}, + {file = "pydantic_core-2.3.0-cp312-none-win_amd64.whl", hash = "sha256:019c5c41941438570dfc7d3f0ae389b2425add1775a357ce1e83ed1434f943d6"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:27c1bbfb9d84a75cf33b7f19b53c29eb7ead99b235fce52aced5507174ab8f98"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:7cb496e934b71f1ade844ab91d6ccac78a3520e5df02fdb2357f85a71e541e69"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5af2d43b1978958d91351afbcc9b4d0cfe144c46c61740e82aaac8bb39ab1a4d"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3097c39d7d4e8dba2ef86de171dcccad876c36d8379415ba18a5a4d0533510"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_armv7l.whl", hash = "sha256:dd3b023f3317dbbbc775e43651ce1a31a9cea46216ad0b5be37afc18a2007699"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_ppc64le.whl", hash = "sha256:27babb9879bf2c45ed655d02639f4c30e2b9ef1b71ce59c2305bbf7287910a18"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_s390x.whl", hash = "sha256:2183a9e18cdc0de53bdaa1675f237259162abeb62d6ac9e527c359c1074dc55d"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c089d8e7f1b4db08b2f8e4107304eec338df046275dad432635a9be9531e2fc8"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2f10aa5452b865818dd0137f568d443f5e93b60a27080a01aa4b7512c7ba13a3"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:f642313d559f9d9a00c4de6820124059cc3342a0d0127b18301de2c680d5ea40"}, + {file = "pydantic_core-2.3.0-cp37-none-win32.whl", hash = "sha256:45327fc57afbe3f2c3d7f54a335d5cecee8a9fdb3906a2fbed8af4092f4926df"}, + {file = "pydantic_core-2.3.0-cp37-none-win_amd64.whl", hash = "sha256:e427b66596a6441a5607dfc0085b47d36073f88da7ac48afd284263b9b99e6ce"}, + {file = "pydantic_core-2.3.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:0b3d781c71b8bfb621ef23b9c874933e2cd33237c1a65cc20eeb37437f8e7e18"}, + {file = "pydantic_core-2.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ad46027dbd5c1db87dc0b49becbe23093b143a20302028d387dae37ee5ef95f5"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39aa09ed7ce2a648c904f79032d16dda29e6913112af8465a7bf710eef23c7ca"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05b4bf8c58409586a7a04c858a86ab10f28c6c1a7c33da65e0326c59d5b0ab16"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_armv7l.whl", hash 
= "sha256:ba2b807d2b62c446120906b8580cddae1d76d3de4efbb95ccc87f5e35c75b4b2"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_ppc64le.whl", hash = "sha256:ea955e4ed21f4bbb9b83fea09fc6af0bed82e69ecf6b35ec89237a0a49633033"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_s390x.whl", hash = "sha256:06884c07956526ac9ebfef40fe21a11605569b8fc0e2054a375fb39c978bf48f"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f868e731a18b403b88aa434d960489ceeed0ddeb44ebc02389540731a67705e0"}, + {file = "pydantic_core-2.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:cb08fab0fc1db15c277b72e33ac74ad9c0c789413da8984a3eacb22a94b42ef4"}, + {file = "pydantic_core-2.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6ca34c29fbd6592de5fd39e80c1993634d704c4e7e14ba54c87b2c7c53da68fe"}, + {file = "pydantic_core-2.3.0-cp38-none-win32.whl", hash = "sha256:cd782807d35c8a41aaa7d30b5107784420eefd9fdc1c760d86007d43ae00b15d"}, + {file = "pydantic_core-2.3.0-cp38-none-win_amd64.whl", hash = "sha256:01f56d5ee70b1d39c0fd08372cc5142274070ab7181d17c86035f130eebc05b8"}, + {file = "pydantic_core-2.3.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:78b1ac0151271ce62bc2b33755f1043eda6a310373143a2f27e2bcd3d5fc8633"}, + {file = "pydantic_core-2.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:64bfd2c35a2c350f73ac52dc134d8775f93359c4c969280a6fe5301b5b6e7431"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:937c0fe9538f1212b62df6a68f8d78df3572fe3682d9a0dd8851eac8a4e46063"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d965c7c4b40d1cedec9188782e98bd576f9a04868835604200c3a6e817b824f"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_armv7l.whl", hash = "sha256:ad442b8585ed4a3c2d22e4bf7b465d9b7d281e055b09719a8aeb5b576422dc9b"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_ppc64le.whl", hash = "sha256:4bf20c9722821fce766e685718e739deeccc60d6bc7be5029281db41f999ee0c"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_s390x.whl", hash = "sha256:f3dd5333049b5b3faa739e0f40b77cc8b7a1aded2f2da0e28794c81586d7b08a"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dc5f516b24d24bc9e8dd9305460899f38302b3c4f9752663b396ef9848557bf"}, + {file = "pydantic_core-2.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:055f7ea6b1fbb37880d66d70eefd22dd319b09c79d2cb99b1dbfeb34b653b0b2"}, + {file = "pydantic_core-2.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:af693a89db6d6ac97dd84dd7769b3f2bd9007b578127d0e7dda03053f4d3b34b"}, + {file = "pydantic_core-2.3.0-cp39-none-win32.whl", hash = "sha256:f60e31e3e15e8c294bf70c60f8ae4d0c3caf3af8f26466e9aa8ea4c01302749b"}, + {file = "pydantic_core-2.3.0-cp39-none-win_amd64.whl", hash = "sha256:2b79f3681481f4424d7845cc7a261d5a4baa810d656b631fa844dc9967b36a7b"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:a666134b41712e30a71afaa26deeb4da374179f769fa49784cdf0e7698880fab"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c119e9227487ad3d7c3c737d896afe548a6be554091f9745da1f4b489c40561"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73929a2fb600a2333fce2efd92596cff5e6bf8946e20e93c067b220760064862"}, + {file = 
"pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:41bbc2678a5b6a19371b2cb51f30ccea71f0c14b26477d2d884fed761cea42c7"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dcbff997f47d45bf028bda4c3036bb3101e89a3df271281d392b6175f71c71d1"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:afa8808159169368b66e4fbeafac6c6fd8f26246dc4d0dcc2caf94bd9cf1b828"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:12be3b5f54f8111ca38e6b7277f26c23ba5cb3344fae06f879a0a93dfc8b479e"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ed5babdcd3d052ba5cf8832561f18df20778c7ccf12587b2d82f7bf3bf259a0e"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d642e5c029e2acfacf6aa0a7a3e822086b3b777c70d364742561f9ca64c1ffc"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ba3073eb38a1294e8c7902989fb80a7a147a69db2396818722bd078476586a0"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d5146a6749b1905e04e62e0ad4622f079e5582f8b3abef5fb64516c623127908"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:deeb64335f489c3c11949cbd1d1668b3f1fb2d1c6a5bf40e126ef7bf95f9fa40"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:31acc37288b8e69e4849f618c3d5cf13b58077c1a1ff9ade0b3065ba974cd385"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:e09d9f6d722de9d4c1c5f122ea9bc6b25a05f975457805af4dcab7b0128aacbf"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ba6a8cf089222a171b8f84e6ec2d10f7a9d14f26be3a347b14775a8741810676"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef1fd1b24e9bcddcb168437686677104e205c8e25b066e73ffdf331d3bb8792b"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eda1a89c4526826c0a87d33596a4cd15b8f58e9250f503e39af1699ba9c878e8"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3e9a18401a28db4358da2e191508702dbf065f2664c710708cdf9552b9fa50c"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:a439fd0d45d51245bbde799726adda5bd18aed3fa2b01ab2e6a64d6d13776fa3"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:bf6a1d2c920cc9528e884850a4b2ee7629e3d362d5c44c66526d4097bbb07a1a"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e33fcbea3b63a339dd94de0fc442fefacfe681cc7027ce63f67af9f7ceec7422"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:bf3ed993bdf4754909f175ff348cf8f78d4451215b8aa338633f149ca3b1f37a"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7584171eb3115acd4aba699bc836634783f5bd5aab131e88d8eeb8a3328a4a72"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1624baa76d1740711b2048f302ae9a6d73d277c55a8c3e88b53b773ebf73a971"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash 
= "sha256:06f33f695527f5a86e090f208978f9fd252c9cfc7e869d3b679bd71f7cb2c1fa"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7ecf0a67b212900e92f328181fed02840d74ed39553cdb38d27314e2b9c89dfa"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:45fa1e8ad6f4367ad73674ca560da8e827cc890eaf371f3ee063d6d7366a207b"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8d0dbcc57839831ae79fd24b1b83d42bc9448d79feaf3ed3fb5cbf94ffbf3eb7"}, + {file = "pydantic_core-2.3.0.tar.gz", hash = "sha256:5cfb5ac4e82c47d5dc25b209dd4c3989e284b80109f9e08b33c895080c424b4f"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + +[[package]] +name = "pydantic-settings" +version = "2.0.2" +description = "Settings management using Pydantic" +optional = true +python-versions = ">=3.7" +files = [ + {file = "pydantic_settings-2.0.2-py3-none-any.whl", hash = "sha256:6183a2abeab465d5a3ab69758e9a22d38b0cc2ba193f0b85f6971a252ea630f6"}, + {file = "pydantic_settings-2.0.2.tar.gz", hash = "sha256:342337fff50b23585e807a86dec85037900972364435c55c2fc00d16ff080539"}, +] + +[package.dependencies] +pydantic = ">=2.0.1" +python-dotenv = ">=0.21.0" [[package]] name = "pygments" @@ -1102,6 +1211,20 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-dotenv" +version = "1.0.0" +description = "Read key-value pairs from a .env file and set them as environment variables" +optional = true +python-versions = ">=3.8" +files = [ + {file = "python-dotenv-1.0.0.tar.gz", hash = "sha256:a8df96034aae6d2d50a4ebe8216326c61c3eb64836776504fcca410e5937a3ba"}, + {file = "python_dotenv-1.0.0-py3-none-any.whl", hash = "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a"}, +] + +[package.extras] +cli = ["click (>=5.0)"] + [[package]] name = "pywin32-ctypes" version = "0.2.0" @@ -1628,9 +1751,9 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] -server = ["fastapi", "sse-starlette", "uvicorn"] +server = ["fastapi", "pydantic-settings", "sse-starlette", "uvicorn"] [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "da42c48a426b64ce393b4febca1be0e2ea0fe9d48cedb2392b390d4a49276474" +content-hash = "6290d1ac980de004aa71b1f0521af7ffff236c6c647c772b25c130bc3f0b2bc6" diff --git a/pyproject.toml b/pyproject.toml index 02273b9..8baefbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,9 @@ typing-extensions = "^4.7.1" numpy = "^1.24.4" diskcache = "^5.6.1" uvicorn = { version = "^0.22.0", optional = true } -fastapi = { version = "^0.99.1", optional = true } -sse-starlette = { version = "^1.6.1", optional = true } +fastapi = { version = ">=0.100.0", optional = true } +sse-starlette = { version = ">=1.6.1", optional = true } +pydantic-settings = { version = ">=2.0.1", optional = true } [tool.poetry.group.dev.dependencies] black = "^23.3.0" @@ -32,7 +33,7 @@ httpx = "^0.24.1" scikit-build = "0.17.6" [tool.poetry.extras] -server = ["uvicorn>=0.22.0", "fastapi>=0.100.0", "pydantic-settings>=2.0.1", "sse-starlette>=1.6.1"] +server = ["uvicorn", "fastapi", "pydantic-settings", "sse-starlette"] [build-system] requires = [ From 426dbfe3f4518114a9a8d8ceb80146c89e56aee3 Mon Sep 17 
00:00:00 2001 From: Shouyi Wang Date: Tue, 25 Jul 2023 18:29:59 +1000 Subject: [PATCH 02/38] Change tensor_split from array to pointer --- llama_cpp/llama.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 9679b2e..66c76c9 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -273,13 +273,12 @@ class Llama: self.params.low_vram = low_vram self.tensor_split = tensor_split - self._c_tensor_split = None + self._p_tensor_split = None if self.tensor_split is not None: - #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES - FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value - self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd - self.params.tensor_split = self._c_tensor_split + FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split) + self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd + self.params.tensor_split = self._p_tensor_split self.params.rope_freq_base = rope_freq_base self.params.rope_freq_scale = rope_freq_scale From 0687a3092b50a30fc02d43f7d32644da5444efef Mon Sep 17 00:00:00 2001 From: Ihsan Soydemir Date: Tue, 25 Jul 2023 20:49:44 +0200 Subject: [PATCH 03/38] Fix typo in 70B path --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0e6f218..ea1e07f 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,7 @@ llm = Llama(model_path="./models/7B/ggml-model.bin", n_ctx=2048) Llama2 70b must set the `n_gqa` parameter (grouped-query attention factor) to 8 when loading: ```python -llm = Llama(model_path="./models/7B/ggml-model.bin", n_gqa=8) +llm = Llama(model_path="./models/70B/ggml-model.bin", n_gqa=8) ``` ## Web Server From bfbbc8db149fa7d27b8e0e62d2efece0962cd00e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 Jul 2023 17:54:39 +0000 Subject: [PATCH 04/38] Bump mkdocs-material from 9.1.18 to 9.1.19 Bumps [mkdocs-material](https://github.com/squidfunk/mkdocs-material) from 9.1.18 to 9.1.19. - [Release notes](https://github.com/squidfunk/mkdocs-material/releases) - [Changelog](https://github.com/squidfunk/mkdocs-material/blob/master/CHANGELOG) - [Commits](https://github.com/squidfunk/mkdocs-material/compare/9.1.18...9.1.19) --- updated-dependencies: - dependency-name: mkdocs-material dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index f68444a..8f6614b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -790,13 +790,13 @@ mkdocs = ">=1.1" [[package]] name = "mkdocs-material" -version = "9.1.18" +version = "9.1.19" description = "Documentation that simply works" optional = false python-versions = ">=3.7" files = [ - {file = "mkdocs_material-9.1.18-py3-none-any.whl", hash = "sha256:5bcf8fb79ac2f253c0ffe93fa181cba87718c6438f459dc4180ac7418cc9a450"}, - {file = "mkdocs_material-9.1.18.tar.gz", hash = "sha256:981dd39979723d4cda7cfc77bbbe5e54922d5761a7af23fb8ba9edb52f114b13"}, + {file = "mkdocs_material-9.1.19-py3-none-any.whl", hash = "sha256:fb0a149294b319aedf36983919d8c40c9e566db21ead16258e20ebd2e6c0961c"}, + {file = "mkdocs_material-9.1.19.tar.gz", hash = "sha256:73b94b08c765e92a80645aac58d6a741fc5f587deec2b715489c714827b15a6f"}, ] [package.dependencies] @@ -1756,4 +1756,4 @@ server = ["fastapi", "pydantic-settings", "sse-starlette", "uvicorn"] [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "6290d1ac980de004aa71b1f0521af7ffff236c6c647c772b25c130bc3f0b2bc6" +content-hash = "0b968f93c1075722aac40acb68ff9cfdb3374b2a213cac623bf498041e7bf277" diff --git a/pyproject.toml b/pyproject.toml index 06734c2..560d244 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ black = "^23.3.0" twine = "^4.0.2" mkdocs = "^1.4.3" mkdocstrings = {extras = ["python"], version = "^0.22.0"} -mkdocs-material = "^9.1.18" +mkdocs-material = "^9.1.19" pytest = "^7.4.0" httpx = "^0.24.1" scikit-build = "0.17.6" From ecdfe4fbd3d26273c37bd8a27dc307c5a3a97bab Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 Jul 2023 19:08:28 +0000 Subject: [PATCH 05/38] Bump black from 23.3.0 to 23.7.0 Bumps [black](https://github.com/psf/black) from 23.3.0 to 23.7.0. - [Release notes](https://github.com/psf/black/releases) - [Changelog](https://github.com/psf/black/blob/main/CHANGES.md) - [Commits](https://github.com/psf/black/compare/23.3.0...23.7.0) --- updated-dependencies: - dependency-name: black dependency-type: direct:development update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- poetry.lock | 53 ++++++++++++++++++++++++-------------------------- pyproject.toml | 2 +- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/poetry.lock b/poetry.lock index 8f6614b..9a8a39a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -36,36 +36,33 @@ trio = ["trio (>=0.16,<0.22)"] [[package]] name = "black" -version = "23.3.0" +version = "23.7.0" description = "The uncompromising code formatter." 
optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "black-23.3.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:0945e13506be58bf7db93ee5853243eb368ace1c08a24c65ce108986eac65915"}, - {file = "black-23.3.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:67de8d0c209eb5b330cce2469503de11bca4085880d62f1628bd9972cc3366b9"}, - {file = "black-23.3.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:7c3eb7cea23904399866c55826b31c1f55bbcd3890ce22ff70466b907b6775c2"}, - {file = "black-23.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32daa9783106c28815d05b724238e30718f34155653d4d6e125dc7daec8e260c"}, - {file = "black-23.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:35d1381d7a22cc5b2be2f72c7dfdae4072a3336060635718cc7e1ede24221d6c"}, - {file = "black-23.3.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:a8a968125d0a6a404842fa1bf0b349a568634f856aa08ffaff40ae0dfa52e7c6"}, - {file = "black-23.3.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c7ab5790333c448903c4b721b59c0d80b11fe5e9803d8703e84dcb8da56fec1b"}, - {file = "black-23.3.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:a6f6886c9869d4daae2d1715ce34a19bbc4b95006d20ed785ca00fa03cba312d"}, - {file = "black-23.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f3c333ea1dd6771b2d3777482429864f8e258899f6ff05826c3a4fcc5ce3f70"}, - {file = "black-23.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:11c410f71b876f961d1de77b9699ad19f939094c3a677323f43d7a29855fe326"}, - {file = "black-23.3.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:1d06691f1eb8de91cd1b322f21e3bfc9efe0c7ca1f0e1eb1db44ea367dff656b"}, - {file = "black-23.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50cb33cac881766a5cd9913e10ff75b1e8eb71babf4c7104f2e9c52da1fb7de2"}, - {file = "black-23.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e114420bf26b90d4b9daa597351337762b63039752bdf72bf361364c1aa05925"}, - {file = "black-23.3.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:48f9d345675bb7fbc3dd85821b12487e1b9a75242028adad0333ce36ed2a6d27"}, - {file = "black-23.3.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:714290490c18fb0126baa0fca0a54ee795f7502b44177e1ce7624ba1c00f2331"}, - {file = "black-23.3.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:064101748afa12ad2291c2b91c960be28b817c0c7eaa35bec09cc63aa56493c5"}, - {file = "black-23.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:562bd3a70495facf56814293149e51aa1be9931567474993c7942ff7d3533961"}, - {file = "black-23.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:e198cf27888ad6f4ff331ca1c48ffc038848ea9f031a3b40ba36aced7e22f2c8"}, - {file = "black-23.3.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:3238f2aacf827d18d26db07524e44741233ae09a584273aa059066d644ca7b30"}, - {file = "black-23.3.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:f0bd2f4a58d6666500542b26354978218a9babcdc972722f4bf90779524515f3"}, - {file = "black-23.3.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:92c543f6854c28a3c7f39f4d9b7694f9a6eb9d3c5e2ece488c327b6e7ea9b266"}, - {file = "black-23.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a150542a204124ed00683f0db1f5cf1c2aaaa9cc3495b7a3b5976fb136090ab"}, - {file = "black-23.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6b39abdfb402002b8a7d030ccc85cf5afff64ee90fa4c5aebc531e3ad0175ddb"}, - {file = "black-23.3.0-py3-none-any.whl", hash = 
"sha256:ec751418022185b0c1bb7d7736e6933d40bbb14c14a0abcf9123d1b159f98dd4"}, - {file = "black-23.3.0.tar.gz", hash = "sha256:1c7b8d606e728a41ea1ccbd7264677e494e87cf630e399262ced92d4a8dac940"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"}, + {file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"}, + {file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"}, + {file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"}, + {file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"}, + {file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"}, + {file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"}, + {file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"}, + {file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"}, + {file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"}, + {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"}, ] [package.dependencies] @@ -1756,4 +1753,4 @@ server = ["fastapi", "pydantic-settings", "sse-starlette", "uvicorn"] [metadata] lock-version = "2.0" python-versions = "^3.8.1" 
-content-hash = "0b968f93c1075722aac40acb68ff9cfdb3374b2a213cac623bf498041e7bf277" +content-hash = "6383ea27faa7fa2602fc33b0a93f72f87dd6f274fed8354206246c127ff458bb" diff --git a/pyproject.toml b/pyproject.toml index 560d244..e7c6b21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ sse-starlette = { version = ">=1.6.1", optional = true } pydantic-settings = { version = ">=2.0.1", optional = true } [tool.poetry.group.dev.dependencies] -black = "^23.3.0" +black = "^23.7.0" twine = "^4.0.2" mkdocs = "^1.4.3" mkdocstrings = {extras = ["python"], version = "^0.22.0"} From 583d63351acbcc0c034bb3b5da1cefa2d84c93dc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 27 Jul 2023 03:33:07 +0000 Subject: [PATCH 06/38] Bump uvicorn from 0.22.0 to 0.23.1 Bumps [uvicorn](https://github.com/encode/uvicorn) from 0.22.0 to 0.23.1. - [Release notes](https://github.com/encode/uvicorn/releases) - [Changelog](https://github.com/encode/uvicorn/blob/master/CHANGELOG.md) - [Commits](https://github.com/encode/uvicorn/compare/0.22.0...0.23.1) --- updated-dependencies: - dependency-name: uvicorn dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- poetry.lock | 11 ++++++----- pyproject.toml | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9a8a39a..8fad112 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1652,18 +1652,19 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" -version = "0.22.0" +version = "0.23.1" description = "The lightning-fast ASGI server." optional = true -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "uvicorn-0.22.0-py3-none-any.whl", hash = "sha256:e9434d3bbf05f310e762147f769c9f21235ee118ba2d2bf1155a7196448bd996"}, - {file = "uvicorn-0.22.0.tar.gz", hash = "sha256:79277ae03db57ce7d9aa0567830bbb51d7a612f54d6e1e3e92da3ef24c2c8ed8"}, + {file = "uvicorn-0.23.1-py3-none-any.whl", hash = "sha256:1d55d46b83ee4ce82b4e82f621f2050adb3eb7b5481c13f9af1744951cae2f1f"}, + {file = "uvicorn-0.23.1.tar.gz", hash = "sha256:da9b0c8443b2d7ee9db00a345f1eee6db7317432c9d4400f5049cc8d358383be"}, ] [package.dependencies] click = ">=7.0" h11 = ">=0.8" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} [package.extras] standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] @@ -1753,4 +1754,4 @@ server = ["fastapi", "pydantic-settings", "sse-starlette", "uvicorn"] [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "6383ea27faa7fa2602fc33b0a93f72f87dd6f274fed8354206246c127ff458bb" +content-hash = "95adf05a0934d122dd601835c2d6353cc1dda03e4e8a5c5af02bfd1369afa74a" diff --git a/pyproject.toml b/pyproject.toml index e7c6b21..2ac020a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ python = "^3.8.1" typing-extensions = "^4.7.1" numpy = "^1.24.4" diskcache = "^5.6.1" -uvicorn = { version = "^0.22.0", optional = true } +uvicorn = { version = "^0.23.1", optional = true } fastapi = { version = ">=0.100.0", optional = true } sse-starlette = { version = ">=1.6.1", optional = true } pydantic-settings = { version = ">=2.0.1", optional = true } From 3e77eea7ec4d5f2b98ec91248f3d85ea9489aa98 Mon Sep 17 00:00:00 2001 From: Ihsan Soydemir Date: Thu, 27 Jul 2023 19:44:15 +0200 Subject: [PATCH 07/38] Fix OpenBLAS Docker 
build Current build produces the following: `RuntimeError: Failed to load shared library '/usr/local/lib/python3.11/site-packages/llama_cpp/libllama.so': /usr/local/lib/python3.11/site-packages/llama_cpp/libllama.so: undefined symbol: cblas_sgemm` --- docker/openblas_simple/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/openblas_simple/Dockerfile b/docker/openblas_simple/Dockerfile index 8231bdb..020c34d 100644 --- a/docker/openblas_simple/Dockerfile +++ b/docker/openblas_simple/Dockerfile @@ -9,7 +9,7 @@ COPY . . RUN apt update && apt install -y libopenblas-dev ninja-build build-essential RUN python -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi uvicorn sse-starlette pydantic-settings -RUN LLAMA_OPENBLAS=1 pip install llama_cpp_python --verbose +RUN CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama_cpp_python --verbose # Run the server CMD python3 -m llama_cpp.server From abc538fcd55a79293f68bc46b8d078ee7b88bc66 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 28 Jul 2023 01:43:00 -0400 Subject: [PATCH 08/38] fix: annoying bug where attribute exceptions were drowning out file not found exceptions --- llama_cpp/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 66c76c9..b52a398 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1503,10 +1503,10 @@ class Llama: return self._convert_text_completion_to_chat(completion) def __del__(self): - if self.model is not None: + if hasattr(self, "model") and self.model is not None: llama_cpp.llama_free_model(self.model) self.model = None - if self.ctx is not None: + if hasattr(self, "ctx") and self.ctx is not None: llama_cpp.llama_free(self.ctx) self.ctx = None From a9b9f0397cd86509b3ea359e5260e329464dc032 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 28 Jul 2023 01:53:08 -0400 Subject: [PATCH 09/38] Format --- llama_cpp/llama.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index b52a398..2537af2 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -224,7 +224,7 @@ class Llama: rope_freq_base: float = 10000.0, rope_freq_scale: float = 1.0, n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b - rms_norm_eps: Optional[float] = None, # (TEMPORARY) + rms_norm_eps: Optional[float] = None,  # (TEMPORARY) verbose: bool = True, ): """Load a llama.cpp model from `model_path`. 
@@ -277,7 +277,9 @@ class Llama: if self.tensor_split is not None: FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split) - self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd + self._p_tensor_split = ctypes.POINTER(ctypes.c_float)( + FloatArray + ) # keep a reference to the array so it is not gc'd self.params.tensor_split = self._p_tensor_split self.params.rope_freq_base = rope_freq_base @@ -959,9 +961,7 @@ class Llama: for token in remaining_tokens: token_end_position += len(self.detokenize([token])) # Check if stop sequence is in the token - if token_end_position >= ( - remaining_length - first_stop_position - ): + if token_end_position >= (remaining_length - first_stop_position): break logprobs_or_none: Optional[CompletionLogprobs] = None if logprobs is not None: From ce57920e608d075335dbd291476420f2abc491be Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 28 Jul 2023 14:45:18 -0400 Subject: [PATCH 10/38] Suppress llama.cpp output when loading model. --- llama_cpp/llama.py | 23 +++++++++++++++++++---- llama_cpp/utils.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 4 deletions(-) create mode 100644 llama_cpp/utils.py diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 2537af2..47f71e9 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -27,6 +27,8 @@ from .llama_types import * import numpy as np import numpy.typing as npt +from .utils import suppress_stdout_stderr + class BaseLlamaCache(ABC): """Base cache class for a llama.cpp model.""" @@ -308,12 +310,25 @@ class Llama: if not os.path.exists(model_path): raise ValueError(f"Model path does not exist: {model_path}") - self.model = llama_cpp.llama_load_model_from_file( - self.model_path.encode("utf-8"), self.params - ) + if verbose: + self.model = llama_cpp.llama_load_model_from_file( + self.model_path.encode("utf-8"), self.params + ) + else: + with suppress_stdout_stderr(): + self.model = llama_cpp.llama_load_model_from_file( + self.model_path.encode("utf-8"), self.params + ) assert self.model is not None - self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params) + if verbose: + self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params) + else: + with suppress_stdout_stderr(): + print("here") + self.ctx = llama_cpp.llama_new_context_with_model( + self.model, self.params + ) assert self.ctx is not None diff --git a/llama_cpp/utils.py b/llama_cpp/utils.py new file mode 100644 index 0000000..c14f53f --- /dev/null +++ b/llama_cpp/utils.py @@ -0,0 +1,38 @@ +import os +import sys + + +class suppress_stdout_stderr(object): + # Oddly enough this works better than the contextlib version + def __enter__(self): + self.outnull_file = open(os.devnull, "w") + self.errnull_file = open(os.devnull, "w") + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + 
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() From 4a2f41a80396bbefc9aec70ff47dc2ded28ac716 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 31 Jul 2023 20:31:16 +0000 Subject: [PATCH 11/38] Bump fastapi from 0.100.0 to 0.100.1 Bumps [fastapi](https://github.com/tiangolo/fastapi) from 0.100.0 to 0.100.1. - [Release notes](https://github.com/tiangolo/fastapi/releases) - [Commits](https://github.com/tiangolo/fastapi/compare/0.100.0...0.100.1) --- updated-dependencies: - dependency-name: fastapi dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- poetry.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index 8fad112..95c18ae 100644 --- a/poetry.lock +++ b/poetry.lock @@ -384,13 +384,13 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.100.0" +version = "0.100.1" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = true python-versions = ">=3.7" files = [ - {file = "fastapi-0.100.0-py3-none-any.whl", hash = "sha256:271662daf986da8fa98dc2b7c7f61c4abdfdccfb4786d79ed8b2878f172c6d5f"}, - {file = "fastapi-0.100.0.tar.gz", hash = "sha256:acb5f941ea8215663283c10018323ba7ea737c571b67fc7e88e9469c7eb1d12e"}, + {file = "fastapi-0.100.1-py3-none-any.whl", hash = "sha256:ec6dd52bfc4eff3063cfcd0713b43c87640fefb2687bbbe3d8a08d94049cdf32"}, + {file = "fastapi-0.100.1.tar.gz", hash = "sha256:522700d7a469e4a973d92321ab93312448fbe20fca9c8da97effc7e7bc56df23"}, ] [package.dependencies] From ecb72cc0a258da4220b28842edd68107f25c19b7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 31 Jul 2023 20:32:04 +0000 Subject: [PATCH 12/38] Bump mkdocs-material from 9.1.19 to 9.1.21 Bumps [mkdocs-material](https://github.com/squidfunk/mkdocs-material) from 9.1.19 to 9.1.21. - [Release notes](https://github.com/squidfunk/mkdocs-material/releases) - [Changelog](https://github.com/squidfunk/mkdocs-material/blob/master/CHANGELOG) - [Commits](https://github.com/squidfunk/mkdocs-material/compare/9.1.19...9.1.21) --- updated-dependencies: - dependency-name: mkdocs-material dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- poetry.lock | 23 +++++++++++++---------- pyproject.toml | 2 +- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/poetry.lock b/poetry.lock index 8fad112..c36ffed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -744,13 +744,13 @@ files = [ [[package]] name = "mkdocs" -version = "1.4.3" +version = "1.5.1" description = "Project documentation with Markdown." 
optional = false python-versions = ">=3.7" files = [ - {file = "mkdocs-1.4.3-py3-none-any.whl", hash = "sha256:6ee46d309bda331aac915cd24aab882c179a933bd9e77b80ce7d2eaaa3f689dd"}, - {file = "mkdocs-1.4.3.tar.gz", hash = "sha256:5955093bbd4dd2e9403c5afaf57324ad8b04f16886512a3ee6ef828956481c57"}, + {file = "mkdocs-1.5.1-py3-none-any.whl", hash = "sha256:67e889f8d8ba1fe5decdfc59f5f8f21d6a8925a129339e93dede303bdea03a98"}, + {file = "mkdocs-1.5.1.tar.gz", hash = "sha256:f2f323c62fffdf1b71b84849e39aef56d6852b3f0a5571552bca32cefc650209"}, ] [package.dependencies] @@ -759,16 +759,19 @@ colorama = {version = ">=0.4", markers = "platform_system == \"Windows\""} ghp-import = ">=1.0" importlib-metadata = {version = ">=4.3", markers = "python_version < \"3.10\""} jinja2 = ">=2.11.1" -markdown = ">=3.2.1,<3.4" +markdown = ">=3.2.1" +markupsafe = ">=2.0.1" mergedeep = ">=1.3.4" packaging = ">=20.5" +pathspec = ">=0.11.1" +platformdirs = ">=2.2.0" pyyaml = ">=5.1" pyyaml-env-tag = ">=0.1" watchdog = ">=2.0" [package.extras] i18n = ["babel (>=2.9.0)"] -min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4)", "ghp-import (==1.0)", "importlib-metadata (==4.3)", "jinja2 (==2.11.1)", "markdown (==3.2.1)", "markupsafe (==2.0.1)", "mergedeep (==1.3.4)", "packaging (==20.5)", "pyyaml (==5.1)", "pyyaml-env-tag (==0.1)", "typing-extensions (==3.10)", "watchdog (==2.0)"] +min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4)", "ghp-import (==1.0)", "importlib-metadata (==4.3)", "jinja2 (==2.11.1)", "markdown (==3.2.1)", "markupsafe (==2.0.1)", "mergedeep (==1.3.4)", "packaging (==20.5)", "pathspec (==0.11.1)", "platformdirs (==2.2.0)", "pyyaml (==5.1)", "pyyaml-env-tag (==0.1)", "typing-extensions (==3.10)", "watchdog (==2.0)"] [[package]] name = "mkdocs-autorefs" @@ -787,20 +790,20 @@ mkdocs = ">=1.1" [[package]] name = "mkdocs-material" -version = "9.1.19" +version = "9.1.21" description = "Documentation that simply works" optional = false python-versions = ">=3.7" files = [ - {file = "mkdocs_material-9.1.19-py3-none-any.whl", hash = "sha256:fb0a149294b319aedf36983919d8c40c9e566db21ead16258e20ebd2e6c0961c"}, - {file = "mkdocs_material-9.1.19.tar.gz", hash = "sha256:73b94b08c765e92a80645aac58d6a741fc5f587deec2b715489c714827b15a6f"}, + {file = "mkdocs_material-9.1.21-py3-none-any.whl", hash = "sha256:58bb2f11ef240632e176d6f0f7d1cff06be1d11c696a5a1b553b808b4280ed47"}, + {file = "mkdocs_material-9.1.21.tar.gz", hash = "sha256:71940cdfca84ab296b6362889c25395b1621273fb16c93deda257adb7ff44ec8"}, ] [package.dependencies] colorama = ">=0.4" jinja2 = ">=3.0" markdown = ">=3.2" -mkdocs = ">=1.4.2" +mkdocs = ">=1.5.0" mkdocs-material-extensions = ">=1.1" pygments = ">=2.14" pymdown-extensions = ">=9.9.1" @@ -1754,4 +1757,4 @@ server = ["fastapi", "pydantic-settings", "sse-starlette", "uvicorn"] [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "95adf05a0934d122dd601835c2d6353cc1dda03e4e8a5c5af02bfd1369afa74a" +content-hash = "00bcb182a7f4e32ac8e7f6559f37e8a06fb911bac4b8556b8cfdc9201c945d94" diff --git a/pyproject.toml b/pyproject.toml index 2ac020a..9cf9710 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ black = "^23.7.0" twine = "^4.0.2" mkdocs = "^1.4.3" mkdocstrings = {extras = ["python"], version = "^0.22.0"} -mkdocs-material = "^9.1.19" +mkdocs-material = "^9.1.21" pytest = "^7.4.0" httpx = "^0.24.1" scikit-build = "0.17.6" From 0cc8d8282ab3ea2324fb0d48da1032c8c6f136c4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" 
<49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Aug 2023 21:25:18 +0000 Subject: [PATCH 13/38] Bump uvicorn from 0.23.1 to 0.23.2 Bumps [uvicorn](https://github.com/encode/uvicorn) from 0.23.1 to 0.23.2. - [Release notes](https://github.com/encode/uvicorn/releases) - [Changelog](https://github.com/encode/uvicorn/blob/master/CHANGELOG.md) - [Commits](https://github.com/encode/uvicorn/compare/0.23.1...0.23.2) --- updated-dependencies: - dependency-name: uvicorn dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 3cfabb8..1dcbfe6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1655,13 +1655,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" -version = "0.23.1" +version = "0.23.2" description = "The lightning-fast ASGI server." optional = true python-versions = ">=3.8" files = [ - {file = "uvicorn-0.23.1-py3-none-any.whl", hash = "sha256:1d55d46b83ee4ce82b4e82f621f2050adb3eb7b5481c13f9af1744951cae2f1f"}, - {file = "uvicorn-0.23.1.tar.gz", hash = "sha256:da9b0c8443b2d7ee9db00a345f1eee6db7317432c9d4400f5049cc8d358383be"}, + {file = "uvicorn-0.23.2-py3-none-any.whl", hash = "sha256:1f9be6558f01239d4fdf22ef8126c39cb1ad0addf76c40e760549d2c2f43ab53"}, + {file = "uvicorn-0.23.2.tar.gz", hash = "sha256:4d3cc12d7727ba72b64d12d3cc7743124074c0a69f7b201512fc50c3e3f1569a"}, ] [package.dependencies] @@ -1757,4 +1757,4 @@ server = ["fastapi", "pydantic-settings", "sse-starlette", "uvicorn"] [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "00bcb182a7f4e32ac8e7f6559f37e8a06fb911bac4b8556b8cfdc9201c945d94" +content-hash = "6718d680fa89f9518a232c1110ba43958d3e21c54c4dbd9129effa4f40a02b81" diff --git a/pyproject.toml b/pyproject.toml index 9cf9710..e3fcd0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ python = "^3.8.1" typing-extensions = "^4.7.1" numpy = "^1.24.4" diskcache = "^5.6.1" -uvicorn = { version = "^0.23.1", optional = true } +uvicorn = { version = "^0.23.2", optional = true } fastapi = { version = ">=0.100.0", optional = true } sse-starlette = { version = ">=1.6.1", optional = true } pydantic-settings = { version = ">=2.0.1", optional = true } From 39978ccaf5b8ca85bc6b72d719e746ea305ad37f Mon Sep 17 00:00:00 2001 From: bretello Date: Thu, 3 Aug 2023 18:22:52 +0200 Subject: [PATCH 14/38] add `mul_mat_q` parameter This also fixes a crash when loading the 70b llama2 model on MacOS with metal and `n_gpu_layers=1` --- llama_cpp/llama_cpp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 423a4a0..bbb2a1e 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -181,6 +181,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) # // Keep the booleans together to avoid misalignment during copy-by-value. 
# bool low_vram; // if true, reduce VRAM usage at the cost of performance +# bool mul_mat_q; // if true, use experimental mul_mat_q kernels # bool f16_kv; // use fp16 for KV cache # bool logits_all; // the llama_eval() call computes all logits, not just the last one # bool vocab_only; // only load the vocabulary, no weights @@ -203,6 +204,7 @@ class llama_context_params(Structure): ("progress_callback", llama_progress_callback), ("progress_callback_user_data", c_void_p), ("low_vram", c_bool), + ("mul_mat_q", c_bool), ("f16_kv", c_bool), ("logits_all", c_bool), ("vocab_only", c_bool), From 9f499af6b0253273d03834eac6b36c5767c57d48 Mon Sep 17 00:00:00 2001 From: bretello Date: Thu, 3 Aug 2023 18:23:26 +0200 Subject: [PATCH 15/38] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 41c6741..8183159 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 41c674161fb2459bdf7806d1eebead15bc5d046e +Subproject commit 8183159cf3def112f6d1fe94815fce70e1bffa12 From ac188a21f3a2e9530b36e6103de36f5ba655376e Mon Sep 17 00:00:00 2001 From: c0sogi Date: Sat, 5 Aug 2023 14:43:35 +0900 Subject: [PATCH 16/38] Added low level grammar API --- llama_cpp/llama_cpp.py | 34 + llama_cpp/llama_grammar.py | 1331 ++++++++++++++++++++++++++++++++++++ 2 files changed, 1365 insertions(+) create mode 100644 llama_cpp/llama_grammar.py diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 423a4a0..d9a68a9 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -1157,6 +1157,23 @@ _lib.llama_sample_temperature.argtypes = [ _lib.llama_sample_temperature.restype = None +# LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); +def llama_sample_grammar( + ctx: llama_context_p, + candidates, # type: _Pointer[llama_token_data_array] + grammar, # type: llama_grammar_p +): + return _lib.llama_sample_grammar(ctx, candidates, grammar) + + +_lib.llama_sample_grammar.argtypes = [ + llama_context_p, + llama_token_data_array_p, + llama_grammar_p, +] +_lib.llama_sample_grammar.restype = None + + # @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. # @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. # @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. 
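# Illustrative sketch only (not part of the committed diff): how the two grammar
# bindings added in this patch -- llama_sample_grammar above and
# llama_grammar_accept_token in the next hunk -- are meant to compose during one
# sampling step. `ctx`, `candidates`, and `grammar` are assumed to be supplied by
# the caller: `ctx` a loaded llama_context_p, `candidates` a populated
# llama_token_data_array, and `grammar` the result of llama_grammar_init (see the
# parser added in llama_cpp/llama_grammar.py later in this patch).
import ctypes

import llama_cpp


def sample_with_grammar(ctx, candidates, grammar):
    # Mask out candidate tokens that the grammar does not currently allow.
    llama_cpp.llama_sample_grammar(ctx, ctypes.byref(candidates), grammar)
    # Pick a token as usual (temperature / top-k / top-p samplers would run before this).
    token = llama_cpp.llama_sample_token(ctx, ctypes.byref(candidates))
    # Feed the chosen token back so the grammar state advances for the next step.
    llama_cpp.llama_grammar_accept_token(ctx, grammar, token)
    return token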
@@ -1244,6 +1261,23 @@ _lib.llama_sample_token.argtypes = [ _lib.llama_sample_token.restype = llama_token +# /// @details Accepts the sampled token into the grammar +# LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); +def llama_grammar_accept_token( + ctx: llama_context_p, + grammar: llama_grammar_p, + token: llama_token, +) -> None: + _lib.llama_grammar_accept_token(ctx, grammar, token) + + +_lib.llama_grammar_accept_token.argtypes = [ + llama_context_p, + llama_grammar_p, + llama_token, +] +_lib.llama_grammar_accept_token.restype = None + # Performance information diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py new file mode 100644 index 0000000..07a120f --- /dev/null +++ b/llama_cpp/llama_grammar.py @@ -0,0 +1,1331 @@ +"""C++ implementation of the llama grammar parser.""" +# flake8: noqa +import argparse +from pathlib import Path +import sys +from ctypes import Array, c_int, c_size_t, c_uint32, cast +from enum import Enum +from itertools import islice +from typing import ( + Callable, + Generic, + List, + Optional, + OrderedDict, + TextIO, + Tuple, + TypeVar, + Union, +) + +import llama_cpp + +T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") +W = TypeVar("W") +size_t = uint8_t = uint32_t = int +static_cast_uint8_t = ord + + +class Sentinel: + pass + + +class const_char_p: + """C++ implementation of const char*.""" + + def __init__(self, value: Union[str, "const_char_p"]): + if isinstance(value, const_char_p): + # We're copying an existing const_char_p + self.value = value.value + self.pos = value.pos + return + + # We're creating a new const_char_p + self.value = value + self.pos = 0 + + def __str__(self) -> str: + return self.value[self.pos :] + + def __add__(self, increment: int) -> "const_char_p": + # To avoid side effects, we create a new const_char_p object + new = self.__class__(self.value) + new.pos = self.pos + increment + return new + + def __sub__(self, decrement: int) -> "const_char_p": + # To avoid side effects, we create a new const_char_p object + new = self.__class__(self.value) + new.pos = self.pos - decrement + return new + + def __lt__(self, other: "const_char_p") -> bool: + return self.pos < other.pos and self.value == other.value + + def __gt__(self, other: "const_char_p") -> bool: + return self.pos > other.pos and self.value == other.value + + def __eq__(self, other: "const_char_p") -> bool: + return self.pos == other.pos and self.value == other.value + + def add(self, other: "const_char_p") -> int: + if self.value != other.value: + raise ValueError("Can't add pointers to different strings") + return self.pos + other.pos + + def sub(self, other: "const_char_p") -> int: + if self.value != other.value: + raise ValueError("Can't subtract pointers to different strings") + return self.pos - other.pos + + def plus_plus(self) -> None: + self.pos += 1 + + def minus_minus(self) -> None: + self.pos -= 1 + + @property + def derefer(self) -> Optional[str]: + if self.pos >= len(self.value): + # We've reached the end of the string + return None + + return self.value[self.pos] + + +class std__vector(Generic[T], List[T]): + """C++ implementation of std::vector.""" + + class iterator: + def __init__(self, vector: "std__vector[T]", index: int): + self._vector = vector + self._index = index + self._version = vector._version + + def _check_version(self): + if self._version != self._vector._version: + raise RuntimeError("Iterator used after vector was modified.") + + def __iter__(self): + 
return self + + def __next__(self) -> T: + self._check_version() + if self._index >= self._vector.size(): + raise StopIteration + value = self._vector[self._index] + self._index += 1 + return value + + def __add__(self, value: int) -> "std__vector[T].iterator": + return self.__class__(self._vector, self._index + value) + + def __sub__(self, value: int) -> "std__vector[T].iterator": + return self.__class__(self._vector, self._index - value) + + def __init__(self): + self._version = 0 + + def modify(self): + # This is a bit of a hack to make sure iterators are invalidated + self._version += 1 + + def push_back(self, value: T) -> None: + self.modify() + self.append(value) + + def pop_back(self) -> None: + self.modify() + if not self.empty(): + self.pop() + + def back(self) -> T: + return self[-1] + + def size(self) -> int: + return len(self) + + # def clear(self) -> None: + # super().clear() + + def empty(self) -> bool: + return self.size() == 0 + + def data(self) -> "std__vector[T]": + return self + + def resize( + self, + new_size: int, + fill_value_factory: Optional[Callable[[], T]] = None, + ) -> None: + if new_size > self.size(): + if fill_value_factory is None: + raise ValueError( + "A fill value factory function must be provided." + ) + self.reserve(new_size, fill_value_factory) + elif new_size < self.size(): + self[:] = self[:new_size] + + def reserve( + self, capacity: int, fill_value_factory: Callable[[], T] + ) -> None: + if capacity > self.size(): + fill_value = fill_value_factory() + self.extend([fill_value] * (capacity - self.size())) + + def front(self) -> T: + if not self.empty(): + return self[0] + else: + raise IndexError("Vector is empty.") + + def assign(self, count: int, value: T) -> None: + self.clear() + self.extend([value] * count) + + def insert( + self, + pos: "std__vector[T].iterator", + first: "std__vector[T].iterator", + last: "std__vector[T].iterator", + ) -> None: + self[pos._index : pos._index] = list( + islice(first._vector, first._index, last._index) + ) + + def begin(self) -> "std__vector[T].iterator": + return self.iterator(self, 0) + + def end(self) -> "std__vector[T].iterator": + return self.iterator(self, self.size()) + + +class std__map(Generic[T, U], OrderedDict[T, U]): + """C++ implementation of std::map.""" + + class iterator(Generic[V, W]): + def __init__(self, _map: "std__map[T, U]", key: Union[T, Sentinel]): + self._map = _map + self.iter = iter(_map) + self.key = key + self._advance() + + def _sanitize_key(self) -> T: + if isinstance(self.key, Sentinel): + raise StopIteration + return self.key + + def _advance(self) -> None: + try: + while next(self.iter) != self.key: + pass + except StopIteration: + self.key = Sentinel() + + def __next__(self) -> Tuple[T, U]: + key = self._sanitize_key() + if key in self._map: + value = self._map[key] + self._advance() + return key, value + else: + raise StopIteration + + def get(self) -> Tuple[T, U]: + key = self._sanitize_key() + return key, self._map[key] + + @property + def first(self) -> T: + return self._sanitize_key() + + @property + def second(self) -> U: + return self._map[self._sanitize_key()] + + def insert( + self, key: T, value: U + ) -> Tuple["std__map[T, U].iterator[T, U]", bool]: + if key in self: + return self.iterator(self, key), False + else: + self[key] = value + return self.iterator(self, key), True + + def find(self, key: T) -> "std__map[T, U].iterator[T, U]": + if key in self: + return self.iterator(self, key) + else: + return self.end() + + def at(self, key: T) -> U: + if key in self: + 
return self[key] + else: + raise KeyError("The provided key is not found in the map.") + + def erase(self, iterator: "std__map[T, U].iterator[T, U]") -> None: + key = iterator.first + if key in self: + del self[key] + + def size(self) -> int: + return len(self) + + def empty(self) -> bool: + return self.size() == 0 + + def lower_bound(self, key: T) -> "std__map[T, U].iterator[T, U]": + try: + keys = sorted(list(self.keys())) # type: ignore + for k in keys: + if k >= key: + return self.iterator(self, k) + raise ValueError( + "No key found that is not less than the input key" + ) + except TypeError: + raise TypeError("Keys of type T cannot be sorted.") + + def begin(self) -> "std__map[T, U].iterator[T, U]": + return self.iterator(self, next(iter(self))) + + def end(self) -> "std__map[T, U].iterator[T, U]": + return self.iterator(self, Sentinel()) + + +class std__string(str): + def __new__(cls, ptr: const_char_p, length: Optional[int] = None): + if length is not None: + return super().__new__(cls, str(ptr)[:length]) + return super().__new__(cls, str(ptr)) + + +# // grammar element type +# enum llama_gretype { +# // end of rule definition +# LLAMA_GRETYPE_END = 0, + +# // start of alternate definition for rule +# LLAMA_GRETYPE_ALT = 1, + +# // non-terminal element: reference to rule +# LLAMA_GRETYPE_RULE_REF = 2, + +# // terminal element: character (code point) +# LLAMA_GRETYPE_CHAR = 3, + +# // inverse char(s) ([^a], [^a-b] [^abc]) +# LLAMA_GRETYPE_CHAR_NOT = 4, + +# // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to +# // be an inclusive range ([a-z]) +# LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + +# // modifies a preceding LLAMA_GRETYPE_CHAR or +# // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) +# LLAMA_GRETYPE_CHAR_ALT = 6, +# }; +class llama_gretype(Enum): + """grammar element type""" + + LLAMA_GRETYPE_END = 0 # end of rule definition + LLAMA_GRETYPE_ALT = 1 # start of alternate definition for rule + LLAMA_GRETYPE_RULE_REF = 2 # non-terminal element: reference to rule + LLAMA_GRETYPE_CHAR = 3 # terminal element: character (code point) + LLAMA_GRETYPE_CHAR_NOT = 4 # inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_ALT = 6 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + + +# typedef struct llama_grammar_element { +# enum llama_gretype type; +# uint32_t value; // Unicode code point or rule ID +# } llama_grammar_element; + + +# class llama_grammar_element(Structure): +# _fields_ = [ +# ("type", c_int), +# ("value", c_uint32), +# ] + + +class llama_grammar_element: + def __init__(self, type: llama_gretype, value: uint32_t): + self.type = type + self.value = value # Unicode code point or rule ID + + def __repr__(self): # debug + return f"llama_grammar_element({self.type}, {self.value})" + + +# struct parse_state { +# std::map symbol_ids; +# std::vector> rules; +# std::vector c_rules(); +# }; +class parse_state: + def __init__(self): + self.symbol_ids: std__map[str, uint32_t] = std__map() + self.rules: std__vector[ + std__vector[llama_grammar_element] + ] = std__vector() + + # std::vector parse_state::c_rules() { + # std::vector ret; + # for (const auto & rule : rules) { + # ret.push_back(rule.data()); + # } + # return ret; + # } + def c_rules(self) -> std__vector[std__vector[llama_grammar_element]]: + ret = ( + std__vector() + ) 
# type: std__vector[std__vector[llama_grammar_element]] + for rule in self.rules: + ret.push_back(rule.data()) + return ret + + +# struct llama_grammar { +# const std::vector> rules; +# std::vector> stacks; +# }; +class llama_grammar: + def __init__( + self, + rules: std__vector[std__vector[llama_grammar_element]], + stacks: std__vector[std__vector[llama_grammar_element]], + ): + self.rules = rules + self.stacks = stacks + + +# uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { +# uint32_t next_id = static_cast(state.symbol_ids.size()); +# auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); +# return result.first->second; +# } +def get_symbol_id(state: parse_state, src: const_char_p, len: size_t) -> int: + next_id = uint32_t(state.symbol_ids.size()) # type: uint32_t + result = state.symbol_ids.insert(str(std__string(src, len)), next_id) + return result[0].second # type: ignore + + +# uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { +# uint32_t next_id = static_cast(state.symbol_ids.size()); +# state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; +# return next_id; +# } +def generate_symbol_id(state: parse_state, base_name: str) -> uint32_t: + next_id = state.symbol_ids.size() # type: uint32_t + state.symbol_ids[base_name + "_" + str(next_id)] = next_id + return next_id + + +# void add_rule( +# parse_state & state, +# uint32_t rule_id, +# const std::vector & rule) { +# if (state.rules.size() <= rule_id) { +# state.rules.resize(rule_id + 1); +# } +# state.rules[rule_id] = rule; +# } +def add_rule( + state: parse_state, + rule_id: uint32_t, + rule: std__vector[llama_grammar_element], +) -> None: + if state.rules.size() <= rule_id: + state.rules.resize( + rule_id + 1, fill_value_factory=std__vector[llama_grammar_element] + ) + state.rules[rule_id] = rule + + +# std::pair decode_utf8(const char * src) { +# static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; +# uint8_t first_byte = static_cast(*src); +# uint8_t highbits = first_byte >> 4; +# int len = lookup[highbits]; +# uint8_t mask = (1 << (8 - len)) - 1; +# uint32_t value = first_byte & mask; +# const char * end = src + len; // may overrun! +# const char * pos = src + 1; +# for ( ; pos < end && *pos; pos++) { +# value = (value << 6) + (static_cast(*pos) & 0x3F); +# } +# return std::make_pair(value, pos); +# } +def decode_utf8(src: const_char_p) -> Tuple[uint32_t, const_char_p]: + """Decodes a UTF-8 character from the source string.""" + lookup = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4) + first_byte = static_cast_uint8_t(src.derefer or "") # type: uint8_t + highbits = first_byte >> 4 # type: uint8_t + len = lookup[highbits] # type: int + mask = (1 << (8 - len)) - 1 # type: uint8_t + value = first_byte & mask # type: uint32_t + end = src + len # type: const_char_p # may overrun! 
+ pos = src + 1 # type: const_char_p + while pos < end and pos.derefer: + value = (value << 6) + (static_cast_uint8_t(src.derefer or "") & 0x3F) + pos.plus_plus() + return value, pos + + +# bool is_word_char(char c) { +# return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); +# } +def is_word_char(c: str) -> bool: + return ( + ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") + ) + + +# std::pair parse_hex(const char * src, int size) { +# const char * pos = src; +# const char * end = src + size; +# uint32_t value = 0; +# for ( ; pos < end && *pos; pos++) { +# value <<= 4; +# char c = *pos; +# if ('a' <= c && c <= 'f') { +# value += c - 'a' + 10; +# } else if ('A' <= c && c <= 'F') { +# value += c - 'A' + 10; +# } else if ('0' <= c && c <= '9') { +# value += c - '0'; +# } else { +# break; +# } +# } +# if (pos != end) { +# throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); +# } +# return std::make_pair(value, pos); +# } +def parse_hex(src: const_char_p, size: int) -> Tuple[uint32_t, const_char_p]: + pos = const_char_p(src) # type: const_char_p + end = src + size # type: const_char_p + value = 0 # type: uint32_t + while pos < end and pos.derefer: + value <<= 4 + c = pos.derefer # type: str + if "a" <= c <= "f": + value += static_cast_uint8_t(c) - static_cast_uint8_t("a") + 10 + elif "A" <= c <= "F": + value += static_cast_uint8_t(c) - static_cast_uint8_t("A") + 10 + elif "0" <= c <= "9": + value += static_cast_uint8_t(c) - static_cast_uint8_t("0") + else: + break + pos.plus_plus() + if pos != end: + raise RuntimeError( + "expecting " + str(size) + " hex chars at " + str(src) + ) + return (value, pos) + + +# std::pair parse_char(const char * src) { +# if (*src == '\\') { +# switch (src[1]) { +# case 'x': return parse_hex(src + 2, 2); +# case 'u': return parse_hex(src + 2, 4); +# case 'U': return parse_hex(src + 2, 8); +# case 't': return std::make_pair('\t', src + 2); +# case 'r': return std::make_pair('\r', src + 2); +# case 'n': return std::make_pair('\n', src + 2); +# case '\\': +# case '"': +# case '[': +# case ']': +# return std::make_pair(src[1], src + 2); +# default: +# throw std::runtime_error(std::string("unknown escape at ") + src); +# } +# } else if (*src) { +# return decode_utf8(src); +# } +# throw std::runtime_error("unexpected end of input"); +# } +def parse_char(src: const_char_p) -> Tuple[uint32_t, const_char_p]: + if src.derefer == "\\": + switch = (src + 1).derefer # type: Optional[str] + if switch == "x": + return parse_hex(src + 2, 2) + elif switch == "u": + return parse_hex(src + 2, 4) + elif switch == "U": + return parse_hex(src + 2, 8) + elif switch == "t": + return (static_cast_uint8_t("\t"), src + 2) # implicit cast + elif switch == "r": + return (static_cast_uint8_t("\r"), src + 2) # implicit cast + elif switch == "n": + return (static_cast_uint8_t("\n"), src + 2) # implicit cast + elif switch in ("\\", '"', "[", "]"): + return (static_cast_uint8_t(switch), src + 2) # implicit cast + else: + raise RuntimeError("unknown escape at " + str(src)) + elif src.derefer: + return decode_utf8(src) + else: + raise RuntimeError("unexpected end of input") + + +# const char * parse_name(const char * src) { +# const char * pos = src; +# while (is_word_char(*pos)) { +# pos++; +# } +# if (pos == src) { +# throw std::runtime_error(std::string("expecting name at ") + src); +# } +# return pos; +# } +def parse_name(src: const_char_p) -> const_char_p: + pos = const_char_p(src) # type: 
const_char_p + while is_word_char(pos.derefer or ""): + pos.plus_plus() + if pos == src: + raise RuntimeError("expecting name at " + str(src)) + return pos + + +# const char * parse_space(const char * src, bool newline_ok) { +# const char * pos = src; +# while (*pos == ' ' || *pos == '\t' || *pos == '#' || +# (newline_ok && (*pos == '\r' || *pos == '\n'))) { +# if (*pos == '#') { +# while (*pos && *pos != '\r' && *pos != '\n') { +# pos++; +# } +# } else { +# pos++; +# } +# } +# return pos; +# } +def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p: + # Using a copy of `src` to avoid side effects + pos = const_char_p(src) + + while pos.derefer in (" ", "\t", "#") or ( + newline_ok and pos.derefer in ("\r", "\n") + ): + if pos.derefer == "#": + while pos.derefer is not None and pos.derefer not in ("\r", "\n"): + pos.plus_plus() + else: + pos.plus_plus() + + return pos + + +# const char * parse_sequence( +# parse_state & state, +# const char * src, +# const std::string & rule_name, +# std::vector & out_elements, +# bool is_nested) { +def parse_sequence( + state: parse_state, + src: const_char_p, + rule_name: str, + out_elements: std__vector[llama_grammar_element], + is_nested: bool, +) -> const_char_p: + # size_t last_sym_start = out_elements.size(); + # const char * pos = src; + last_sym_start = out_elements.size() # type: size_t + pos = const_char_p(src) # type: const_char_p + # while (*pos) { + while pos.derefer: + # if (*pos == '"') { // literal string + # pos++; + # last_sym_start = out_elements.size(); + # while (*pos != '"') { + # auto char_pair = parse_char(pos); + # pos = char_pair.second; + # out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + # } + # pos = parse_space(pos + 1, is_nested); + if pos.derefer == '"': # literal string + pos.plus_plus() + last_sym_start = out_elements.size() + while pos.derefer != '"': + char_pair = parse_char( + pos + ) # type: Tuple[uint32_t, const_char_p] + pos = char_pair[1] + out_elements.push_back( + llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0] + ) + ) + pos = parse_space(pos + 1, is_nested) + # } else if (*pos == '[') { // char range(s) + # pos++; + # enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + elif pos.derefer == "[": # char range(s) + pos.plus_plus() + start_type = ( + llama_gretype.LLAMA_GRETYPE_CHAR + ) # type: llama_gretype + # if (*pos == '^') { + # pos++; + # start_type = LLAMA_GRETYPE_CHAR_NOT; + # } + # last_sym_start = out_elements.size(); + if pos.derefer == "^": + pos.plus_plus() + start_type = llama_gretype.LLAMA_GRETYPE_CHAR_NOT + last_sym_start = out_elements.size() + # while (*pos != ']') { + # auto char_pair = parse_char(pos); + # pos = char_pair.second; + # enum llama_gretype type = last_sym_start < out_elements.size() + # ? 
LLAMA_GRETYPE_CHAR_ALT + # : start_type; + # out_elements.push_back({type, char_pair.first}); + while pos.derefer != "]": + char_pair = parse_char( + pos + ) # type: Tuple[uint32_t, const_char_p] + pos = char_pair[1] + type = ( + llama_gretype.LLAMA_GRETYPE_CHAR_ALT + if last_sym_start < out_elements.size() + else start_type + ) # type: llama_gretype + out_elements.push_back( + llama_grammar_element(type, char_pair[0]) + ) + # if (pos[0] == '-' && pos[1] != ']') { + # auto endchar_pair = parse_char(pos + 1); + # pos = endchar_pair.second; + # out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + # } + # } + if pos.derefer == "-" and (pos + 1).derefer != "]": + endchar_pair = parse_char( + pos + 1 + ) # type: Tuple[uint32_t, const_char_p] + pos = endchar_pair[1] + out_elements.push_back( + llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + endchar_pair[0], + ) + ) + # pos = parse_space(pos + 1, is_nested); + pos = parse_space(pos + 1, is_nested) + # } else if (is_word_char(*pos)) { // rule reference + # const char * name_end = parse_name(pos); + # uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); + # pos = parse_space(name_end, is_nested); + # last_sym_start = out_elements.size(); + # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + elif is_word_char(pos.derefer): # rule reference + name_end = parse_name(pos) # type: const_char_p + ref_rule_id = get_symbol_id( + state, pos, name_end.sub(pos) + ) # type: uint32_t + pos = parse_space(name_end, is_nested) + last_sym_start = out_elements.size() + out_elements.push_back( + llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id + ) + ) + # } else if (*pos == '(') { // grouping + # // parse nested alternates into synthesized rule + # pos = parse_space(pos + 1, true); + # uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + # pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); + # last_sym_start = out_elements.size(); + # // output reference to synthesized rule + # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + # if (*pos != ')') { + # throw std::runtime_error(std::string("expecting ')' at ") + pos); + # } + # pos = parse_space(pos + 1, is_nested); + elif pos.derefer == "(": # grouping + pos = parse_space(pos + 1, True) + sub_rule_id = generate_symbol_id( + state, rule_name + ) # type: uint32_t + pos = parse_alternates(state, pos, rule_name, sub_rule_id, True) + last_sym_start = out_elements.size() + out_elements.push_back( + llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + ) + ) + if pos.derefer != ")": + raise RuntimeError("expecting ')' at " + str(pos)) + pos = parse_space(pos + 1, is_nested) + # } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator + # if (last_sym_start == out_elements.size()) { + # throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); + # } + elif pos.derefer in ("*", "+", "?"): # repetition operator + if last_sym_start == out_elements.size(): + raise RuntimeError( + "expecting preceding item to */+/? at " + str(pos) + ) + # // apply transformation to previous symbol (last_sym_start to end) according to + # // rewrite rules: + # // S* --> S' ::= S S' | + # // S+ --> S' ::= S S' | S + # // S? 
--> S' ::= S | + # uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + # std::vector sub_rule; + # // add preceding symbol to generated rule + # sub_rule.insert( + # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + sub_rule_id = generate_symbol_id( + state, rule_name + ) # type: uint32_t + sub_rule = std__vector[llama_grammar_element]() + sub_rule.insert( + sub_rule.end(), + out_elements.begin() + last_sym_start, + out_elements.end(), + ) + # if (*pos == '*' || *pos == '+') { + # // cause generated rule to recurse + # sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + # } + # // mark start of alternate def + # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + if pos.derefer in ("*", "+"): + sub_rule.push_back( + llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + ) + ) + sub_rule.push_back( + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT, 0) + ) + # if (*pos == '+') { + # // add preceding symbol as alternate only for '+' (otherwise empty) + # sub_rule.insert( + # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + # } + # sub_rule.push_back({LLAMA_GRETYPE_END, 0}); + # add_rule(state, sub_rule_id, sub_rule); + # // in original rule, replace previous symbol with reference to generated rule + # out_elements.resize(last_sym_start); + # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + # pos = parse_space(pos + 1, is_nested); + if pos.derefer == "+": + sub_rule.insert( + sub_rule.end(), + out_elements.begin() + last_sym_start, + out_elements.end(), + ) + sub_rule.push_back( + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END, 0) + ) + add_rule(state, sub_rule_id, sub_rule) + out_elements.resize(last_sym_start) + out_elements.push_back( + llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + ) + ) + pos = parse_space(pos + 1, is_nested) + # } else { + # break; + # } + else: + break + # } + # return pos; + # } + return pos + + +# const char * parse_alternates( +# parse_state & state, +# const char * src, +# const std::string & rule_name, +# uint32_t rule_id, +# bool is_nested) { +# std::vector rule; +# const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); +# while (*pos == '|') { +# rule.push_back({LLAMA_GRETYPE_ALT, 0}); +# pos = parse_space(pos + 1, true); +# pos = parse_sequence(state, pos, rule_name, rule, is_nested); +# } +# rule.push_back({LLAMA_GRETYPE_END, 0}); +# add_rule(state, rule_id, rule); +# return pos; +# } +def parse_alternates( + state: parse_state, + src: const_char_p, + rule_name: str, + rule_id: uint32_t, + is_nested: bool, +) -> const_char_p: + rule = std__vector() # type: std__vector[llama_grammar_element] + pos = parse_sequence( + state, src, rule_name, rule, is_nested + ) # type: const_char_p + while pos.derefer == "|": + rule.push_back( + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT, 0) + ) + pos = parse_space(pos + 1, True) + pos = parse_sequence(state, pos, rule_name, rule, is_nested) + rule.push_back(llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END, 0)) + add_rule(state, rule_id, rule) + return pos + + +# const char * parse_rule(parse_state & state, const char * src) { +# const char * name_end = parse_name(src); +# const char * pos = parse_space(name_end, false); +# size_t name_len = name_end - src; +# uint32_t rule_id = get_symbol_id(state, src, name_len); +# const std::string name(src, name_len); + +# if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { +# throw 
std::runtime_error(std::string("expecting ::= at ") + pos); +# } +# pos = parse_space(pos + 3, true); + +# pos = parse_alternates(state, pos, name, rule_id, false); + + +# if (*pos == '\r') { +# pos += pos[1] == '\n' ? 2 : 1; +# } else if (*pos == '\n') { +# pos++; +# } else if (*pos) { +# throw std::runtime_error(std::string("expecting newline or end at ") + pos); +# } +# return parse_space(pos, true); +# } +def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: + name_end = parse_name(src) # type: const_char_p + pos = parse_space(name_end, False) # type: const_char_p + name_len = name_end.sub(src) # type: size_t + rule_id = get_symbol_id(state, src, name_len) # type: uint32_t + name = std__string(src, name_len) # type: std__string + + if not ( + pos.derefer == ":" + and (pos + 1).derefer == ":" + and (pos + 2).derefer == "=" + ): + raise RuntimeError("expecting ::= at " + str(pos)) + + pos = parse_space(pos + 3, True) # type: const_char_p + pos = parse_alternates( + state, pos, name, rule_id, False + ) # type: const_char_p + + if pos.derefer == "\r": + pos += 2 if (pos + 1).derefer == "\n" else 1 + elif pos.derefer == "\n": + pos.plus_plus() + elif pos.derefer: + raise RuntimeError("expecting newline or end at " + str(pos)) + return parse_space(pos, True) + + +# parse_state parse(const char * src) { +# try { +# parse_state state; +# const char * pos = parse_space(src, true); +# while (*pos) { +# pos = parse_rule(state, pos); +# } +# return state; +# } catch (const std::exception & err) { +# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); +# return parse_state(); +# } +# } +def parse(src: const_char_p) -> parse_state: + try: + state = parse_state() # type: parse_state + pos = parse_space(src, True) # type: const_char_p + while pos.derefer: + pos = parse_rule(state, pos) + return state + except Exception as err: + print(f"{parse.__name__}: error parsing grammar: {err}") + return parse_state() + + +# void print_grammar_char(FILE * file, uint32_t c) { +# if (0x20 <= c && c <= 0x7f) { +# fprintf(file, "%c", static_cast(c)); +# } else { +# // cop out of encoding UTF-8 +# fprintf(file, "", c); +# } +# } +def print_grammar_char(file: TextIO, c: uint32_t) -> None: + if 0x20 <= c and c <= 0x7F: + file.write(chr(c)) + else: + # cop out of encoding UTF-8 + file.write(f"") + + +# bool is_char_element(llama_grammar_element elem) { +# switch (elem.type) { +# case LLAMA_GRETYPE_CHAR: return true; +# case LLAMA_GRETYPE_CHAR_NOT: return true; +# case LLAMA_GRETYPE_CHAR_ALT: return true; +# case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; +# default: return false; +# } +# } +def is_char_element(elem: llama_grammar_element) -> bool: + return elem.type in ( + llama_gretype.LLAMA_GRETYPE_CHAR, + llama_gretype.LLAMA_GRETYPE_CHAR_NOT, + llama_gretype.LLAMA_GRETYPE_CHAR_ALT, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + ) + + +# void print_rule( +# FILE * file, +# uint32_t rule_id, +# const std::vector & rule, +# const std::map & symbol_id_names) { +def print_rule( + file: TextIO, + rule_id: uint32_t, + rule: std__vector[llama_grammar_element], + symbol_id_names: std__map[uint32_t, str], +) -> None: + # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + # throw std::runtime_error( + # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + # } + # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: + raise RuntimeError( + "malformed rule, does 
not end with LLAMA_GRETYPE_END: " + + str(rule_id) + ) + print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") + # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + # llama_grammar_element elem = rule[i]; + # switch (elem.type) { + # case LLAMA_GRETYPE_END: + # throw std::runtime_error( + # "unexpected end of rule: " + std::to_string(rule_id) + "," + + # std::to_string(i)); + # case LLAMA_GRETYPE_ALT: + # fprintf(file, "| "); + # break; + # case LLAMA_GRETYPE_RULE_REF: + # fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + # break; + # case LLAMA_GRETYPE_CHAR: + # fprintf(file, "["); + # print_grammar_char(file, elem.value); + # break; + # case LLAMA_GRETYPE_CHAR_NOT: + # fprintf(file, "[^"); + # print_grammar_char(file, elem.value); + # break; + # case LLAMA_GRETYPE_CHAR_RNG_UPPER: + # if (i == 0 || !is_char_element(rule[i - 1])) { + # throw std::runtime_error( + # "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + # std::to_string(rule_id) + "," + std::to_string(i)); + # } + # fprintf(file, "-"); + # print_grammar_char(file, elem.value); + # break; + # case LLAMA_GRETYPE_CHAR_ALT: + # if (i == 0 || !is_char_element(rule[i - 1])) { + # throw std::runtime_error( + # "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + # std::to_string(rule_id) + "," + std::to_string(i)); + # } + # print_grammar_char(file, elem.value); + # break; + # } + for i, elem in enumerate(rule[:-1]): + switch = elem.type # type: llama_gretype + if switch == llama_gretype.LLAMA_GRETYPE_END: + raise RuntimeError( + "unexpected end of rule: " + str(rule_id) + "," + str(i) + ) + elif switch == llama_gretype.LLAMA_GRETYPE_ALT: + print("| ", file=file, end="") + elif switch == llama_gretype.LLAMA_GRETYPE_RULE_REF: + print(f"{symbol_id_names.at(elem.value)} ", file=file, end="") + elif switch == llama_gretype.LLAMA_GRETYPE_CHAR: + print("[", file=file, end="") + print_grammar_char(file, elem.value) + elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_NOT: + print("[^", file=file, end="") + print_grammar_char(file, elem.value) + elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER: + if i == 0 or not is_char_element(rule[i - 1]): + raise RuntimeError( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + str(rule_id) + + "," + + str(i) + ) + print("-", file=file, end="") + print_grammar_char(file, elem.value) + elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_ALT: + if i == 0 or not is_char_element(rule[i - 1]): + raise RuntimeError( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + str(rule_id) + + "," + + str(i) + ) + print_grammar_char(file, elem.value) + # if (is_char_element(elem)) { + # switch (rule[i + 1].type) { + # case LLAMA_GRETYPE_CHAR_ALT: + # case LLAMA_GRETYPE_CHAR_RNG_UPPER: + # break; + # default: + # fprintf(file, "] "); + if is_char_element(elem): + if rule[i + 1].type in ( + llama_gretype.LLAMA_GRETYPE_CHAR_ALT, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + ): + pass + else: + print("] ", file=file, end="") + # } + # } + # } + # fprintf(file, "\n"); + # } + print(file=file) + + +# void print_grammar(FILE * file, const parse_state & state) { +# try { +# std::map symbol_id_names; +# for (auto kv : state.symbol_ids) { +# symbol_id_names[kv.second] = kv.first; +# } +# for (size_t i = 0, end = state.rules.size(); i < end; i++) { +# // fprintf(file, "%zu: ", i); +# // print_rule_binary(file, state.rules[i]); +# print_rule(file, i, state.rules[i], symbol_id_names); +# // fprintf(file, "\n"); +# } +# } catch (const std::exception & err) { +# 
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); +# } +# } +def print_grammar(file: TextIO, state: parse_state) -> None: + try: + symbol_id_names = std__map() # type: std__map[uint32_t, str] + for kv in state.symbol_ids.items(): + symbol_id_names[kv[1]] = kv[0] + + for i, rule in enumerate(state.rules): + print_rule(file, i, rule, symbol_id_names) + except Exception as err: + print( + f"{print_grammar.__name__}: error printing grammar: {err}", + file=sys.stderr, + ) + + +def convert_to_rules( + llama_grammar_elements: std__vector[std__vector[llama_grammar_element]], +) -> Array[llama_cpp.llama_grammar_element_p]: + """Make an Array object that is used for `llama_grammer_init`""" + + # Step 1: Convert to c_llama_grammar_element + llama_grammar_element_p_p = ( + [] + ) # type: List[List[llama_cpp.llama_grammar_element]] + for subvector in llama_grammar_elements: + llama_grammar_element_p_p.append([]) + for elem in subvector: + c_llama_grammar_element = llama_cpp.llama_grammar_element() + c_llama_grammar_element.type = c_int(elem.type.value) + c_llama_grammar_element.value = c_uint32(elem.value) + llama_grammar_element_p_p[-1].append(c_llama_grammar_element) + + # Step 2: Convert each list to llama_grammar_element array and get pointer + element_arrays = [ + (llama_cpp.llama_grammar_element * len(sublist))(*sublist) + for sublist in llama_grammar_element_p_p + ] # type: List[Array[llama_cpp.llama_grammar_element]] + + # Step 3: Get pointer of each array + element_array_pointers = [ + cast(sublist, llama_cpp.llama_grammar_element_p) + for sublist in element_arrays + ] # type: List[llama_cpp.llama_grammar_element_p] + + # Step 4: Make array of these pointers and get its pointer + return (llama_cpp.llama_grammar_element_p * len(element_array_pointers))( + *element_array_pointers + ) + + +def parse_grammar_init_args( + bnf: str, +) -> Tuple[Array[llama_cpp.llama_grammar_element_p], c_size_t, c_size_t]: + """Parse a GBNF string and return tuple of `grammar rules` and `root symbol id`""" + parsed_grammar = parse(const_char_p(bnf)) # type: parse_state + if parsed_grammar.rules.empty(): + raise Exception( + f"{parse_grammar_init_args.__name__}: error parsing grammar file: parsed_grammar.rules is empty" + ) + print(f"{parse_grammar_init_args.__name__} grammar:", file=sys.stderr) + print_grammar(sys.stdout, parsed_grammar) + print(file=sys.stderr) + grammar_rules = ( + parsed_grammar.c_rules() + ) # type: std__vector[std__vector[llama_grammar_element]] + return ( + convert_to_rules(grammar_rules), + c_size_t(grammar_rules.size()), + c_size_t(parsed_grammar.symbol_ids.at("root")), + ) + + +def parse_grammar_init_args_from_file( + bnf_path: Union[str, Path] +) -> Tuple[Array[llama_cpp.llama_grammar_element_p], c_size_t, c_size_t]: + """Parse a GBNF file and return tuple of `grammar rules` and `root symbol id`""" + try: + with open(bnf_path) as f: + params_grammer = f.read() + except Exception as err: + raise Exception( + f"{parse_grammar_init_args_from_file.__name__}: error reading grammar file: {err}" + ) + + if params_grammer: + return parse_grammar_init_args(params_grammer) + + raise Exception( + f"{parse_grammar_init_args_from_file.__name__}: error parsing grammar file: params_grammer is empty" + ) + + +# def get_grammar_p(bnf: str) -> llama_cpp.llama_grammar_p: +# """Parse a GBNF string and return pointer to `llama_grammar`""" + +# grammar_rules, root_symbol_id = parse_rules(bnf) + +# grammar_element_p_p = convert_to_double_ptr( +# grammar_rules +# ) # type: 
llama_cpp.llama_grammar_element_p_p + +# c_llama_grammar_p = llama_cpp.llama_grammar_init( +# grammar_element_p_p, +# c_size_t(grammar_rules.size()), +# c_size_t(root_symbol_id), +# ) # type: llama_cpp.llama_grammar_p +# return c_llama_grammar_p + + +# def get_grammar_p_from_file( +# bnf_path: Union[str, Path] +# ) -> llama_cpp.llama_grammar_p: +# """Parse a GBNF file and return pointer to `llama_grammar`""" +# try: +# with open(bnf_path) as f: +# params_grammer = f.read() +# except Exception as err: +# raise Exception( +# f"{get_grammar_p_from_file.__name__}: error reading grammar file: {err}" +# ) + +# if params_grammer: +# return get_grammar_p(params_grammer) + +# raise Exception( +# f"{get_grammar_p_from_file.__name__}: error parsing grammar file: params_grammer is empty" +# ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate C++ parser from GBNF grammar" + ) + parser.add_argument( + "-g", + "--grammar", + type=str, + default="./vendor/llama.cpp/grammars/json.gbnf", + help="path to GBNF grammar file", + ) + + args = parser.parse_args() + rules, n_rules, start_rule_index = parse_grammar_init_args_from_file( + args.grammar + ) + llama_grammar_p = llama_cpp.llama_grammar_init( + rules, + n_rules, + start_rule_index, + ) # type: llama_cpp.llama_grammar_p + + # ----- USAGE: + # llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p) + # llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...) + + # ----- SAMPLE OUTPUT: + # main grammar: + # root ::= object + # object ::= [{] ws object_11 [}] ws + # value ::= object | array | string | number | value_6 ws + # array ::= [[] ws array_15 []] ws + # string ::= ["] string_18 ["] ws + # number ::= number_19 number_25 number_29 ws + # value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l] + # ws ::= ws_31 + # object_8 ::= string [:] ws value object_10 + # object_9 ::= [,] ws string [:] ws value + # object_10 ::= object_9 object_10 | + # object_11 ::= object_8 | + # array_12 ::= value array_14 + # array_13 ::= [,] ws value + # array_14 ::= array_13 array_14 | + # array_15 ::= array_12 | + # string_16 ::= [^"\] | [\] string_17 + # string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] + # string_18 ::= string_16 string_18 | + # number_19 ::= number_20 number_21 + # number_20 ::= [-] | + # number_21 ::= [0-9] | [1-9] number_22 + # number_22 ::= [0-9] number_22 | + # number_23 ::= [.] number_24 + # number_24 ::= [0-9] number_24 | [0-9] + # number_25 ::= number_23 | + # number_26 ::= [eE] number_27 number_28 + # number_27 ::= [-+] | + # number_28 ::= [0-9] number_28 | [0-9] + # number_29 ::= number_26 | + # ws_30 ::= [ ] ws + # ws_31 ::= ws_30 | From 097fba25e53866beb08d1cff250a00d75e178127 Mon Sep 17 00:00:00 2001 From: Mike Zeng Date: Sat, 5 Aug 2023 02:00:04 -0500 Subject: [PATCH 17/38] Fixed spelling error "lowe-level API" to "low-level API" --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ea1e07f..7c515d0 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/ggm ## Low-level API The low-level API is a direct [`ctypes`](https://docs.python.org/3/library/ctypes.html) binding to the C API provided by `llama.cpp`. 
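As a rough, illustrative sketch of driving these bindings directly (the model path and buffer sizes are placeholders, and the authoritative signatures are the ones declared in `llama_cpp/llama_cpp.py`; plain Python ints and bools are converted by ctypes according to the declared argtypes):

```python
import llama_cpp

params = llama_cpp.llama_context_default_params()
model = llama_cpp.llama_load_model_from_file(b"./models/7B/ggml-model.bin", params)
ctx = llama_cpp.llama_new_context_with_model(model, params)

# The C API fills a caller-provided ctypes array of llama_token.
max_tokens = int(params.n_ctx)
tokens = (llama_cpp.llama_token * max_tokens)()
n_tokens = llama_cpp.llama_tokenize(
    ctx,
    b"Q: Name the planets in the solar system? A: ",
    tokens,
    max_tokens,
    True,  # add_bos
)

llama_cpp.llama_free(ctx)
llama_cpp.llama_free_model(model)
```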
-The entire lowe-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h). +The entire low-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h). Below is a short example demonstrating how to use the low-level API to tokenize a prompt: From 418aa83b01bc7ea9fdd26546efb9d9899061cc4a Mon Sep 17 00:00:00 2001 From: c0sogi Date: Mon, 7 Aug 2023 02:21:37 +0900 Subject: [PATCH 18/38] Added grammar based sampling --- llama_cpp/llama.py | 36 +- llama_cpp/llama_grammar.py | 1044 +++++++++++++++++------------------- 2 files changed, 537 insertions(+), 543 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 66c76c9..ab99ee5 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import sys import uuid import time @@ -23,6 +24,7 @@ import ctypes from . import llama_cpp from .llama_types import * +from .llama_grammar import LlamaGrammar import numpy as np import numpy.typing as npt @@ -223,6 +225,7 @@ class Llama: tensor_split: Optional[List[float]] = None, rope_freq_base: float = 10000.0, rope_freq_scale: float = 1.0, + grammar: Optional[Union[str, Path]] = None, n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b rms_norm_eps: Optional[float] = None, # (TEMPORARY) verbose: bool = True, @@ -248,6 +251,7 @@ class Llama: tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split. rope_freq_base: Base frequency for rope sampling. rope_freq_scale: Scale factor for rope sampling. + grammar: Path to a BNF grammar file to use for grammar based sampling. verbose: Print verbose output to stderr. 
Raises: @@ -358,6 +362,12 @@ class Llama: self.scores: npt.NDArray[np.single] = np.ndarray( (n_ctx, self._n_vocab), dtype=np.single ) + if grammar is not None: + self.grammar = LlamaGrammar.from_file( + grammar + ) # type: Optional[LlamaGrammar] + else: + self.grammar = None @property def _input_ids(self) -> npt.NDArray[np.intc]: @@ -542,8 +552,16 @@ class Llama: ) if not penalize_nl: candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit) + + if self.grammar is not None: + llama_cpp.llama_sample_grammar( + ctx=self.ctx, + candidates=llama_cpp.ctypes.byref(candidates), # type: ignore + grammar=self.grammar.grammar, + ) + if temp.value == 0.0: - return llama_cpp.llama_sample_token_greedy( + id = llama_cpp.llama_sample_token_greedy( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore ) @@ -555,7 +573,7 @@ class Llama: candidates=llama_cpp.ctypes.byref(candidates), # type: ignore temp=temp, ) - return llama_cpp.llama_sample_token_mirostat( + id = llama_cpp.llama_sample_token_mirostat( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore tau=mirostat_tau, @@ -570,7 +588,7 @@ class Llama: candidates=llama_cpp.ctypes.byref(candidates), # type: ignore temp=temp, ) - return llama_cpp.llama_sample_token_mirostat_v2( + id = llama_cpp.llama_sample_token_mirostat_v2( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore tau=mirostat_tau, @@ -607,10 +625,17 @@ class Llama: candidates=llama_cpp.ctypes.byref(candidates), # type: ignore temp=temp, ) - return llama_cpp.llama_sample_token( + id = llama_cpp.llama_sample_token( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore ) + if self.grammar is not None: + llama_cpp.llama_grammar_accept_token( + ctx=self.ctx, + grammar=self.grammar.grammar, + token=llama_cpp.ctypes.c_int(id), + ) + return id def sample( self, @@ -1509,6 +1534,9 @@ class Llama: if self.ctx is not None: llama_cpp.llama_free(self.ctx) self.ctx = None + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar.grammar) + self.grammar = None def __getstate__(self): return dict( diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 07a120f..06b2b7f 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -3,7 +3,7 @@ import argparse from pathlib import Path import sys -from ctypes import Array, c_int, c_size_t, c_uint32, cast +from ctypes import * # type: ignore from enum import Enum from itertools import islice from typing import ( @@ -16,293 +16,379 @@ from typing import ( Tuple, TypeVar, Union, + overload, ) import llama_cpp +# Type aliases +llama_grammar_element = llama_cpp.llama_grammar_element +llama_grammar_element_p = llama_cpp.llama_grammar_element_p +llama_grammar_p = llama_cpp.llama_grammar_p + +# Type variables +Ptr = TypeVar("Ptr", bound="const_char_p") T = TypeVar("T") U = TypeVar("U") V = TypeVar("V") W = TypeVar("W") -size_t = uint8_t = uint32_t = int -static_cast_uint8_t = ord class Sentinel: - pass + """Used to mark the end of a iterator of std::vector & std::map.""" + + +class LlamaGrammar: + """Keeps reference counts of all the arguments, so that they are not + garbage collected by Python.""" + + def __init__( + self, + parsed_grammar: "parse_state", + ) -> None: + grammar_rules = ( + parsed_grammar.c_rules() + ) # type: std.vector[std.vector[llama_grammar_element]] + + # Step 1: Convert each list to llama_grammar_element array and get pointer + self.element_arrays = [ + (llama_grammar_element * 
len(sublist))(*sublist) + for sublist in grammar_rules + ] # type: List[Array[llama_grammar_element]] + + # Step 2: Get pointer of each array + self.element_array_pointers = [ + cast(subarray, llama_grammar_element_p) + for subarray in self.element_arrays + ] # type: List[llama_grammar_element_p] + + # Step 3: Make array of these pointers and get its pointer + self.rules = ( + llama_grammar_element_p * len(self.element_array_pointers) + )(*self.element_array_pointers) + + self.n_rules = c_size_t(grammar_rules.size()) + self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root")) + self.grammar = self.init_grammar() + + @classmethod + def from_string(cls, grammar: str) -> "LlamaGrammar": + parsed_grammar = parse(const_char_p(grammar)) # type: parse_state + if parsed_grammar.rules.empty(): + raise ValueError( + f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty" + ) + print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) + print_grammar(sys.stdout, parsed_grammar) + print(file=sys.stderr) + return cls(parsed_grammar) + + @classmethod + def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar": + try: + with open(file) as f: + grammar = f.read() + except Exception as err: + raise Exception( + f"{cls.from_file.__name__}: error reading grammar file: {err}" + ) + + if grammar: + return cls.from_string(grammar) + + raise ValueError( + f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" + ) + + def init_grammar(self) -> llama_grammar_p: + return llama_cpp.llama_grammar_init( + self.rules, self.n_rules, self.start_rule_index + ) class const_char_p: - """C++ implementation of const char*.""" + """C++ implementation of const char *.""" - def __init__(self, value: Union[str, "const_char_p"]): + def __init__(self, value: Union[str, Ptr], move: Optional[int] = None): if isinstance(value, const_char_p): # We're copying an existing const_char_p self.value = value.value - self.pos = value.pos + self.pos = value.pos + (move or 0) return # We're creating a new const_char_p self.value = value - self.pos = 0 + self.pos = move or 0 def __str__(self) -> str: + assert self.value is not None, "null pointer" return self.value[self.pos :] - def __add__(self, increment: int) -> "const_char_p": - # To avoid side effects, we create a new const_char_p object - new = self.__class__(self.value) - new.pos = self.pos + increment - return new + def __getitem__(self, index: int) -> str: + value = str(self) + return value[index] if index < len(value) else "" - def __sub__(self, decrement: int) -> "const_char_p": - # To avoid side effects, we create a new const_char_p object - new = self.__class__(self.value) - new.pos = self.pos - decrement - return new + @overload + def __add__(self: Ptr, other: int) -> Ptr: + ... - def __lt__(self, other: "const_char_p") -> bool: - return self.pos < other.pos and self.value == other.value + @overload + def __add__(self: Ptr, other: Ptr) -> int: + ... 
- def __gt__(self, other: "const_char_p") -> bool: - return self.pos > other.pos and self.value == other.value - - def __eq__(self, other: "const_char_p") -> bool: - return self.pos == other.pos and self.value == other.value - - def add(self, other: "const_char_p") -> int: - if self.value != other.value: - raise ValueError("Can't add pointers to different strings") - return self.pos + other.pos - - def sub(self, other: "const_char_p") -> int: - if self.value != other.value: - raise ValueError("Can't subtract pointers to different strings") - return self.pos - other.pos - - def plus_plus(self) -> None: - self.pos += 1 - - def minus_minus(self) -> None: - self.pos -= 1 - - @property - def derefer(self) -> Optional[str]: - if self.pos >= len(self.value): - # We've reached the end of the string - return None - - return self.value[self.pos] - - -class std__vector(Generic[T], List[T]): - """C++ implementation of std::vector.""" - - class iterator: - def __init__(self, vector: "std__vector[T]", index: int): - self._vector = vector - self._index = index - self._version = vector._version - - def _check_version(self): - if self._version != self._vector._version: - raise RuntimeError("Iterator used after vector was modified.") - - def __iter__(self): - return self - - def __next__(self) -> T: - self._check_version() - if self._index >= self._vector.size(): - raise StopIteration - value = self._vector[self._index] - self._index += 1 - return value - - def __add__(self, value: int) -> "std__vector[T].iterator": - return self.__class__(self._vector, self._index + value) - - def __sub__(self, value: int) -> "std__vector[T].iterator": - return self.__class__(self._vector, self._index - value) - - def __init__(self): - self._version = 0 - - def modify(self): - # This is a bit of a hack to make sure iterators are invalidated - self._version += 1 - - def push_back(self, value: T) -> None: - self.modify() - self.append(value) - - def pop_back(self) -> None: - self.modify() - if not self.empty(): - self.pop() - - def back(self) -> T: - return self[-1] - - def size(self) -> int: - return len(self) - - # def clear(self) -> None: - # super().clear() - - def empty(self) -> bool: - return self.size() == 0 - - def data(self) -> "std__vector[T]": - return self - - def resize( - self, - new_size: int, - fill_value_factory: Optional[Callable[[], T]] = None, - ) -> None: - if new_size > self.size(): - if fill_value_factory is None: - raise ValueError( - "A fill value factory function must be provided." 
- ) - self.reserve(new_size, fill_value_factory) - elif new_size < self.size(): - self[:] = self[:new_size] - - def reserve( - self, capacity: int, fill_value_factory: Callable[[], T] - ) -> None: - if capacity > self.size(): - fill_value = fill_value_factory() - self.extend([fill_value] * (capacity - self.size())) - - def front(self) -> T: - if not self.empty(): - return self[0] - else: - raise IndexError("Vector is empty.") - - def assign(self, count: int, value: T) -> None: - self.clear() - self.extend([value] * count) - - def insert( - self, - pos: "std__vector[T].iterator", - first: "std__vector[T].iterator", - last: "std__vector[T].iterator", - ) -> None: - self[pos._index : pos._index] = list( - islice(first._vector, first._index, last._index) + def __add__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]: + return ( + self.__class__(self.value, self.pos + other) + if isinstance(other, int) + else self.pos + other.pos ) - def begin(self) -> "std__vector[T].iterator": - return self.iterator(self, 0) + @overload + def __sub__(self: Ptr, other: int) -> Ptr: + ... - def end(self) -> "std__vector[T].iterator": - return self.iterator(self, self.size()) + @overload + def __sub__(self: Ptr, other: Ptr) -> int: + ... + + def __sub__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]: + return ( + self.__class__(self.value, self.pos - other) + if isinstance(other, int) + else self.pos - other.pos + ) + + def __eq__(self: Ptr, other: Ptr) -> bool: + assert ( + self.value == other.value + ), "comparing pointers from different strings" + return self.pos == other.pos + + def __lt__(self: Ptr, other: Ptr) -> bool: + assert ( + self.value == other.value + ), "comparing pointers from different strings" + return self.pos < other.pos + + def __gt__(self: Ptr, other: Ptr) -> bool: + assert ( + self.value == other.value + ), "comparing pointers from different strings" + return self.pos > other.pos -class std__map(Generic[T, U], OrderedDict[T, U]): - """C++ implementation of std::map.""" - - class iterator(Generic[V, W]): - def __init__(self, _map: "std__map[T, U]", key: Union[T, Sentinel]): - self._map = _map - self.iter = iter(_map) - self.key = key - self._advance() - - def _sanitize_key(self) -> T: - if isinstance(self.key, Sentinel): - raise StopIteration - return self.key - - def _advance(self) -> None: - try: - while next(self.iter) != self.key: - pass - except StopIteration: - self.key = Sentinel() - - def __next__(self) -> Tuple[T, U]: - key = self._sanitize_key() - if key in self._map: - value = self._map[key] - self._advance() - return key, value - else: - raise StopIteration - - def get(self) -> Tuple[T, U]: - key = self._sanitize_key() - return key, self._map[key] - - @property - def first(self) -> T: - return self._sanitize_key() - - @property - def second(self) -> U: - return self._map[self._sanitize_key()] - - def insert( - self, key: T, value: U - ) -> Tuple["std__map[T, U].iterator[T, U]", bool]: - if key in self: - return self.iterator(self, key), False - else: - self[key] = value - return self.iterator(self, key), True - - def find(self, key: T) -> "std__map[T, U].iterator[T, U]": - if key in self: - return self.iterator(self, key) - else: - return self.end() - - def at(self, key: T) -> U: - if key in self: - return self[key] - else: - raise KeyError("The provided key is not found in the map.") - - def erase(self, iterator: "std__map[T, U].iterator[T, U]") -> None: - key = iterator.first - if key in self: - del self[key] - - def size(self) -> int: - return len(self) - - def 
empty(self) -> bool: - return self.size() == 0 - - def lower_bound(self, key: T) -> "std__map[T, U].iterator[T, U]": - try: - keys = sorted(list(self.keys())) # type: ignore - for k in keys: - if k >= key: - return self.iterator(self, k) - raise ValueError( - "No key found that is not less than the input key" - ) - except TypeError: - raise TypeError("Keys of type T cannot be sorted.") - - def begin(self) -> "std__map[T, U].iterator[T, U]": - return self.iterator(self, next(iter(self))) - - def end(self) -> "std__map[T, U].iterator[T, U]": - return self.iterator(self, Sentinel()) - - -class std__string(str): - def __new__(cls, ptr: const_char_p, length: Optional[int] = None): +class std: + @staticmethod + def string(ptr: const_char_p, length: Optional[int] = None) -> str: + """C++ implementation of std::string constructor.""" + value = str(ptr) if length is not None: - return super().__new__(cls, str(ptr)[:length]) - return super().__new__(cls, str(ptr)) + value = value[:length] + return value + + class vector(Generic[T], List[T]): + """C++ implementation of std::vector.""" + + class iterator: + def __init__(self, vector: "std.vector[T]", index: int): + self._vector = vector + self._index = index + self._version = vector._version + + def _check_version(self): + if self._version != self._vector._version: + raise RuntimeError( + "Iterator used after vector was modified." + ) + + def __iter__(self): + return self + + def __next__(self) -> T: + self._check_version() + if self._index >= self._vector.size(): + raise StopIteration + value = self._vector[self._index] + self._index += 1 + return value + + def __add__(self, value: int) -> "std.vector[T].iterator": + return self.__class__(self._vector, self._index + value) + + def __sub__(self, value: int) -> "std.vector[T].iterator": + return self.__class__(self._vector, self._index - value) + + def __init__(self): + self._version = 0 + + def modify(self): + # This is a bit of a hack to make sure iterators are invalidated + self._version += 1 + + def push_back(self, value: T) -> None: + self.modify() + self.append(value) + + def pop_back(self) -> None: + self.modify() + if not self.empty(): + self.pop() + + def back(self) -> T: + return self[-1] + + def size(self) -> int: + return len(self) + + def clear(self) -> None: + self.modify() + super().clear() + + def empty(self) -> bool: + return self.size() == 0 + + def data(self) -> "std.vector[T]": + return self + + def resize( + self, + new_size: int, + fill_value_factory: Optional[Callable[[], T]] = None, + ) -> None: + if new_size > self.size(): + if fill_value_factory is None: + raise ValueError( + "A fill value factory function must be provided." 
+ ) + self.reserve(new_size, fill_value_factory) + elif new_size < self.size(): + self[:] = self[:new_size] + + def reserve( + self, capacity: int, fill_value_factory: Callable[[], T] + ) -> None: + if capacity > self.size(): + fill_value = fill_value_factory() + self.extend([fill_value] * (capacity - self.size())) + + def front(self) -> T: + if not self.empty(): + return self[0] + else: + raise IndexError("Vector is empty.") + + def assign(self, count: int, value: T) -> None: + self.clear() + self.extend([value] * count) + + def insert( + self, + pos: "std.vector[T].iterator", + first: "std.vector[T].iterator", + last: "std.vector[T].iterator", + ) -> None: + self[pos._index : pos._index] = list( + islice(first._vector, first._index, last._index) + ) + + def begin(self) -> "std.vector[T].iterator": + return self.iterator(self, 0) + + def end(self) -> "std.vector[T].iterator": + return self.iterator(self, self.size()) + + class map(Generic[T, U], OrderedDict[T, U]): + """C++ implementation of std::map.""" + + class iterator(Generic[V, W]): + def __init__(self, _map: "std.map[T, U]", key: Union[T, Sentinel]): + self._map = _map + self.iter = iter(_map) + self.key = key + self._advance() + + def _sanitize_key(self) -> T: + if isinstance(self.key, Sentinel): + raise StopIteration + return self.key + + def _advance(self) -> None: + try: + while next(self.iter) != self.key: + pass + except StopIteration: + self.key = Sentinel() + + def __next__(self) -> Tuple[T, U]: + key = self._sanitize_key() + if key in self._map: + value = self._map[key] + self._advance() + return key, value + else: + raise StopIteration + + def get(self) -> Tuple[T, U]: + key = self._sanitize_key() + return key, self._map[key] + + @property + def first(self) -> T: + return self._sanitize_key() + + @property + def second(self) -> U: + return self._map[self._sanitize_key()] + + def insert( + self, key: T, value: U + ) -> Tuple["std.map[T, U].iterator[T, U]", bool]: + if key in self: + return self.iterator(self, key), False + else: + self[key] = value + return self.iterator(self, key), True + + def find(self, key: T) -> "std.map[T, U].iterator[T, U]": + if key in self: + return self.iterator(self, key) + else: + return self.end() + + def at(self, key: T) -> U: + if key in self: + return self[key] + else: + raise KeyError("The provided key is not found in the map.") + + def erase(self, iterator: "std.map[T, U].iterator[T, U]") -> None: + key = iterator.first + if key in self: + del self[key] + + def size(self) -> int: + return len(self) + + def empty(self) -> bool: + return self.size() == 0 + + def lower_bound(self, key: T) -> "std.map[T, U].iterator[T, U]": + try: + keys = sorted(list(self.keys())) # type: ignore + for k in keys: + if k >= key: + return self.iterator(self, k) + raise ValueError( + "No key found that is not less than the input key" + ) + except TypeError: + raise TypeError("Keys of type T cannot be sorted.") + + def begin(self) -> "std.map[T, U].iterator[T, U]": + return self.iterator(self, next(iter(self))) + + def end(self) -> "std.map[T, U].iterator[T, U]": + return self.iterator(self, Sentinel()) # // grammar element type @@ -343,28 +429,6 @@ class llama_gretype(Enum): LLAMA_GRETYPE_CHAR_ALT = 6 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) -# typedef struct llama_grammar_element { -# enum llama_gretype type; -# uint32_t value; // Unicode code point or rule ID -# } llama_grammar_element; - - -# class 
llama_grammar_element(Structure): -# _fields_ = [ -# ("type", c_int), -# ("value", c_uint32), -# ] - - -class llama_grammar_element: - def __init__(self, type: llama_gretype, value: uint32_t): - self.type = type - self.value = value # Unicode code point or rule ID - - def __repr__(self): # debug - return f"llama_grammar_element({self.type}, {self.value})" - - # struct parse_state { # std::map symbol_ids; # std::vector> rules; @@ -372,10 +436,10 @@ class llama_grammar_element: # }; class parse_state: def __init__(self): - self.symbol_ids: std__map[str, uint32_t] = std__map() - self.rules: std__vector[ - std__vector[llama_grammar_element] - ] = std__vector() + self.symbol_ids: std.map[str, int] = std.map() + self.rules: std.vector[ + std.vector[llama_grammar_element] + ] = std.vector() # std::vector parse_state::c_rules() { # std::vector ret; @@ -384,27 +448,30 @@ class parse_state: # } # return ret; # } - def c_rules(self) -> std__vector[std__vector[llama_grammar_element]]: + def c_rules(self) -> std.vector[std.vector[llama_grammar_element]]: ret = ( - std__vector() - ) # type: std__vector[std__vector[llama_grammar_element]] + std.vector() + ) # type: std.vector[std.vector[llama_grammar_element]] for rule in self.rules: ret.push_back(rule.data()) return ret + def __repr__(self) -> str: + return f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" + # struct llama_grammar { # const std::vector> rules; # std::vector> stacks; # }; -class llama_grammar: - def __init__( - self, - rules: std__vector[std__vector[llama_grammar_element]], - stacks: std__vector[std__vector[llama_grammar_element]], - ): - self.rules = rules - self.stacks = stacks +# class llama_grammar: +# def __init__( +# self, +# rules: std.vector[std.vector[llama_grammar_element]], +# stacks: std.vector[std.vector[llama_grammar_element]], +# ): +# self.rules = rules +# self.stacks = stacks # uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { @@ -412,9 +479,9 @@ class llama_grammar: # auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); # return result.first->second; # } -def get_symbol_id(state: parse_state, src: const_char_p, len: size_t) -> int: - next_id = uint32_t(state.symbol_ids.size()) # type: uint32_t - result = state.symbol_ids.insert(str(std__string(src, len)), next_id) +def get_symbol_id(state: parse_state, src: const_char_p, len: int) -> int: + next_id = state.symbol_ids.size() # type: int + result = state.symbol_ids.insert(std.string(src, len), next_id) return result[0].second # type: ignore @@ -423,8 +490,8 @@ def get_symbol_id(state: parse_state, src: const_char_p, len: size_t) -> int: # state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; # return next_id; # } -def generate_symbol_id(state: parse_state, base_name: str) -> uint32_t: - next_id = state.symbol_ids.size() # type: uint32_t +def generate_symbol_id(state: parse_state, base_name: str) -> int: + next_id = state.symbol_ids.size() # type: int state.symbol_ids[base_name + "_" + str(next_id)] = next_id return next_id @@ -440,12 +507,13 @@ def generate_symbol_id(state: parse_state, base_name: str) -> uint32_t: # } def add_rule( state: parse_state, - rule_id: uint32_t, - rule: std__vector[llama_grammar_element], + rule_id: int, + rule: std.vector[llama_grammar_element], ) -> None: if state.rules.size() <= rule_id: state.rules.resize( - rule_id + 1, fill_value_factory=std__vector[llama_grammar_element] + rule_id + 1, + 
fill_value_factory=std.vector[llama_grammar_element], ) state.rules[rule_id] = rule @@ -464,19 +532,19 @@ def add_rule( # } # return std::make_pair(value, pos); # } -def decode_utf8(src: const_char_p) -> Tuple[uint32_t, const_char_p]: +def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]: """Decodes a UTF-8 character from the source string.""" lookup = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4) - first_byte = static_cast_uint8_t(src.derefer or "") # type: uint8_t - highbits = first_byte >> 4 # type: uint8_t + first_byte = ord(src[0]) # type: int + highbits = first_byte >> 4 # type: int len = lookup[highbits] # type: int - mask = (1 << (8 - len)) - 1 # type: uint8_t - value = first_byte & mask # type: uint32_t + mask = (1 << (8 - len)) - 1 # type: int + value = first_byte & mask # type: int end = src + len # type: const_char_p # may overrun! pos = src + 1 # type: const_char_p - while pos < end and pos.derefer: - value = (value << 6) + (static_cast_uint8_t(src.derefer or "") & 0x3F) - pos.plus_plus() + while pos < end and pos[0]: + value = (value << 6) + (ord(pos[0]) & 0x3F) + pos += 1 return value, pos @@ -511,22 +579,22 @@ def is_word_char(c: str) -> bool: # } # return std::make_pair(value, pos); # } -def parse_hex(src: const_char_p, size: int) -> Tuple[uint32_t, const_char_p]: +def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]: pos = const_char_p(src) # type: const_char_p end = src + size # type: const_char_p - value = 0 # type: uint32_t - while pos < end and pos.derefer: + value = 0 # type: int + while pos < end and pos[0]: value <<= 4 - c = pos.derefer # type: str + c = pos[0] # type: str if "a" <= c <= "f": - value += static_cast_uint8_t(c) - static_cast_uint8_t("a") + 10 + value += ord(c) - ord("a") + 10 elif "A" <= c <= "F": - value += static_cast_uint8_t(c) - static_cast_uint8_t("A") + 10 + value += ord(c) - ord("A") + 10 elif "0" <= c <= "9": - value += static_cast_uint8_t(c) - static_cast_uint8_t("0") + value += ord(c) - ord("0") else: break - pos.plus_plus() + pos += 1 if pos != end: raise RuntimeError( "expecting " + str(size) + " hex chars at " + str(src) @@ -556,26 +624,26 @@ def parse_hex(src: const_char_p, size: int) -> Tuple[uint32_t, const_char_p]: # } # throw std::runtime_error("unexpected end of input"); # } -def parse_char(src: const_char_p) -> Tuple[uint32_t, const_char_p]: - if src.derefer == "\\": - switch = (src + 1).derefer # type: Optional[str] - if switch == "x": +def parse_char(src: const_char_p) -> Tuple[int, const_char_p]: + if src[0] == "\\": + case = src[1] # type: str + if case == "x": return parse_hex(src + 2, 2) - elif switch == "u": + elif case == "u": return parse_hex(src + 2, 4) - elif switch == "U": + elif case == "U": return parse_hex(src + 2, 8) - elif switch == "t": - return (static_cast_uint8_t("\t"), src + 2) # implicit cast - elif switch == "r": - return (static_cast_uint8_t("\r"), src + 2) # implicit cast - elif switch == "n": - return (static_cast_uint8_t("\n"), src + 2) # implicit cast - elif switch in ("\\", '"', "[", "]"): - return (static_cast_uint8_t(switch), src + 2) # implicit cast + elif case == "t": + return (ord("\t"), src + 2) # implicit cast + elif case == "r": + return (ord("\r"), src + 2) # implicit cast + elif case == "n": + return (ord("\n"), src + 2) # implicit cast + elif case in ("\\", '"', "[", "]"): + return (ord(case), src + 2) # implicit cast else: raise RuntimeError("unknown escape at " + str(src)) - elif src.derefer: + elif src[0]: return decode_utf8(src) else: raise 
RuntimeError("unexpected end of input") @@ -593,8 +661,8 @@ def parse_char(src: const_char_p) -> Tuple[uint32_t, const_char_p]: # } def parse_name(src: const_char_p) -> const_char_p: pos = const_char_p(src) # type: const_char_p - while is_word_char(pos.derefer or ""): - pos.plus_plus() + while is_word_char(pos[0]): + pos += 1 if pos == src: raise RuntimeError("expecting name at " + str(src)) return pos @@ -615,18 +683,15 @@ def parse_name(src: const_char_p) -> const_char_p: # return pos; # } def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p: - # Using a copy of `src` to avoid side effects - pos = const_char_p(src) - - while pos.derefer in (" ", "\t", "#") or ( - newline_ok and pos.derefer in ("\r", "\n") + pos = const_char_p(src) # type: const_char_p + while pos[0] in (" ", "\t", "#") or ( + newline_ok and pos[0] in ("\r", "\n") ): - if pos.derefer == "#": - while pos.derefer is not None and pos.derefer not in ("\r", "\n"): - pos.plus_plus() + if pos[0] == "#": + while pos[0] is not None and pos[0] not in ("\r", "\n"): + pos += 1 else: - pos.plus_plus() - + pos += 1 return pos @@ -640,15 +705,15 @@ def parse_sequence( state: parse_state, src: const_char_p, rule_name: str, - out_elements: std__vector[llama_grammar_element], + out_elements: std.vector[llama_grammar_element], is_nested: bool, ) -> const_char_p: # size_t last_sym_start = out_elements.size(); # const char * pos = src; - last_sym_start = out_elements.size() # type: size_t + last_sym_start = out_elements.size() # type: int pos = const_char_p(src) # type: const_char_p # while (*pos) { - while pos.derefer: + while pos[0]: # if (*pos == '"') { // literal string # pos++; # last_sym_start = out_elements.size(); @@ -658,25 +723,23 @@ def parse_sequence( # out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); # } # pos = parse_space(pos + 1, is_nested); - if pos.derefer == '"': # literal string - pos.plus_plus() + if pos[0] == '"': # literal string + pos += 1 last_sym_start = out_elements.size() - while pos.derefer != '"': - char_pair = parse_char( - pos - ) # type: Tuple[uint32_t, const_char_p] + while pos[0] != '"': + char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] out_elements.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0] + llama_gretype.LLAMA_GRETYPE_CHAR.value, char_pair[0] ) ) pos = parse_space(pos + 1, is_nested) # } else if (*pos == '[') { // char range(s) # pos++; # enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; - elif pos.derefer == "[": # char range(s) - pos.plus_plus() + elif pos[0] == "[": # char range(s) + pos += 1 start_type = ( llama_gretype.LLAMA_GRETYPE_CHAR ) # type: llama_gretype @@ -685,8 +748,8 @@ def parse_sequence( # start_type = LLAMA_GRETYPE_CHAR_NOT; # } # last_sym_start = out_elements.size(); - if pos.derefer == "^": - pos.plus_plus() + if pos[0] == "^": + pos += 1 start_type = llama_gretype.LLAMA_GRETYPE_CHAR_NOT last_sym_start = out_elements.size() # while (*pos != ']') { @@ -696,10 +759,8 @@ def parse_sequence( # ? 
LLAMA_GRETYPE_CHAR_ALT # : start_type; # out_elements.push_back({type, char_pair.first}); - while pos.derefer != "]": - char_pair = parse_char( - pos - ) # type: Tuple[uint32_t, const_char_p] + while pos[0] != "]": + char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] type = ( llama_gretype.LLAMA_GRETYPE_CHAR_ALT @@ -707,7 +768,7 @@ def parse_sequence( else start_type ) # type: llama_gretype out_elements.push_back( - llama_grammar_element(type, char_pair[0]) + llama_grammar_element(type.value, char_pair[0]) ) # if (pos[0] == '-' && pos[1] != ']') { # auto endchar_pair = parse_char(pos + 1); @@ -715,14 +776,14 @@ def parse_sequence( # out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); # } # } - if pos.derefer == "-" and (pos + 1).derefer != "]": + if pos[0] == "-" and pos[1] != "]": endchar_pair = parse_char( pos + 1 - ) # type: Tuple[uint32_t, const_char_p] + ) # type: Tuple[int, const_char_p] pos = endchar_pair[1] out_elements.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, endchar_pair[0], ) ) @@ -734,16 +795,16 @@ def parse_sequence( # pos = parse_space(name_end, is_nested); # last_sym_start = out_elements.size(); # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); - elif is_word_char(pos.derefer): # rule reference + elif is_word_char(pos[0]): # rule reference name_end = parse_name(pos) # type: const_char_p ref_rule_id = get_symbol_id( - state, pos, name_end.sub(pos) - ) # type: uint32_t + state, pos, name_end - pos + ) # type: int pos = parse_space(name_end, is_nested) last_sym_start = out_elements.size() out_elements.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id + llama_gretype.LLAMA_GRETYPE_RULE_REF.value, ref_rule_id ) ) # } else if (*pos == '(') { // grouping @@ -758,26 +819,26 @@ def parse_sequence( # throw std::runtime_error(std::string("expecting ')' at ") + pos); # } # pos = parse_space(pos + 1, is_nested); - elif pos.derefer == "(": # grouping + elif pos[0] == "(": # grouping + # parse nested alternates into synthesized rule pos = parse_space(pos + 1, True) - sub_rule_id = generate_symbol_id( - state, rule_name - ) # type: uint32_t + sub_rule_id = generate_symbol_id(state, rule_name) # type: int pos = parse_alternates(state, pos, rule_name, sub_rule_id, True) last_sym_start = out_elements.size() + # output reference to synthesized rule out_elements.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id ) ) - if pos.derefer != ")": + if pos[0] != ")": raise RuntimeError("expecting ')' at " + str(pos)) pos = parse_space(pos + 1, is_nested) # } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator # if (last_sym_start == out_elements.size()) { # throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); # } - elif pos.derefer in ("*", "+", "?"): # repetition operator + elif pos[0] in ("*", "+", "?"): # repetition operator if last_sym_start == out_elements.size(): raise RuntimeError( "expecting preceding item to */+/? 
at " + str(pos) @@ -792,10 +853,10 @@ def parse_sequence( # // add preceding symbol to generated rule # sub_rule.insert( # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - sub_rule_id = generate_symbol_id( - state, rule_name - ) # type: uint32_t - sub_rule = std__vector[llama_grammar_element]() + sub_rule_id = generate_symbol_id(state, rule_name) # type: int + sub_rule = std.vector[ + llama_grammar_element + ]() # type: std.vector[llama_grammar_element] sub_rule.insert( sub_rule.end(), out_elements.begin() + last_sym_start, @@ -807,14 +868,14 @@ def parse_sequence( # } # // mark start of alternate def # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - if pos.derefer in ("*", "+"): + if pos[0] in ("*", "+"): sub_rule.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id ) ) sub_rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT, 0) + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0) ) # if (*pos == '+') { # // add preceding symbol as alternate only for '+' (otherwise empty) @@ -827,20 +888,22 @@ def parse_sequence( # out_elements.resize(last_sym_start); # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); # pos = parse_space(pos + 1, is_nested); - if pos.derefer == "+": + if pos[0] == "+": + # add preceding symbol as alternate only for '+' (otherwise empty) sub_rule.insert( sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end(), ) sub_rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END, 0) + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0) ) add_rule(state, sub_rule_id, sub_rule) + # in original rule, replace previous symbol with reference to generated rule out_elements.resize(last_sym_start) out_elements.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id ) ) pos = parse_space(pos + 1, is_nested) @@ -876,20 +939,22 @@ def parse_alternates( state: parse_state, src: const_char_p, rule_name: str, - rule_id: uint32_t, + rule_id: int, is_nested: bool, ) -> const_char_p: - rule = std__vector() # type: std__vector[llama_grammar_element] + rule = std.vector() # type: std.vector[llama_grammar_element] pos = parse_sequence( state, src, rule_name, rule, is_nested ) # type: const_char_p - while pos.derefer == "|": + while pos[0] == "|": rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT, 0) + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0) ) pos = parse_space(pos + 1, True) pos = parse_sequence(state, pos, rule_name, rule, is_nested) - rule.push_back(llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END, 0)) + rule.push_back( + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0) + ) add_rule(state, rule_id, rule) return pos @@ -921,15 +986,11 @@ def parse_alternates( def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: name_end = parse_name(src) # type: const_char_p pos = parse_space(name_end, False) # type: const_char_p - name_len = name_end.sub(src) # type: size_t - rule_id = get_symbol_id(state, src, name_len) # type: uint32_t - name = std__string(src, name_len) # type: std__string + name_len = name_end - src # type: int + rule_id = get_symbol_id(state, src, name_len) # type: int + name = std.string(src, name_len) # type: str - if not ( - pos.derefer == ":" - and (pos + 1).derefer == ":" - and (pos + 2).derefer == "=" - ): + if 
not (pos[0] == ":" and pos[1] == ":" and pos[2] == "="): raise RuntimeError("expecting ::= at " + str(pos)) pos = parse_space(pos + 3, True) # type: const_char_p @@ -937,11 +998,11 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: state, pos, name, rule_id, False ) # type: const_char_p - if pos.derefer == "\r": - pos += 2 if (pos + 1).derefer == "\n" else 1 - elif pos.derefer == "\n": - pos.plus_plus() - elif pos.derefer: + if pos[0] == "\r": + pos += 2 if pos[1] == "\n" else 1 + elif pos[0] == "\n": + pos += 1 + elif pos[0]: raise RuntimeError("expecting newline or end at " + str(pos)) return parse_space(pos, True) @@ -963,7 +1024,7 @@ def parse(src: const_char_p) -> parse_state: try: state = parse_state() # type: parse_state pos = parse_space(src, True) # type: const_char_p - while pos.derefer: + while pos[0]: pos = parse_rule(state, pos) return state except Exception as err: @@ -979,7 +1040,7 @@ def parse(src: const_char_p) -> parse_state: # fprintf(file, "", c); # } # } -def print_grammar_char(file: TextIO, c: uint32_t) -> None: +def print_grammar_char(file: TextIO, c: int) -> None: if 0x20 <= c and c <= 0x7F: file.write(chr(c)) else: @@ -998,10 +1059,10 @@ def print_grammar_char(file: TextIO, c: uint32_t) -> None: # } def is_char_element(elem: llama_grammar_element) -> bool: return elem.type in ( - llama_gretype.LLAMA_GRETYPE_CHAR, - llama_gretype.LLAMA_GRETYPE_CHAR_NOT, - llama_gretype.LLAMA_GRETYPE_CHAR_ALT, - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + llama_gretype.LLAMA_GRETYPE_CHAR.value, + llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value, + llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, ) @@ -1012,16 +1073,19 @@ def is_char_element(elem: llama_grammar_element) -> bool: # const std::map & symbol_id_names) { def print_rule( file: TextIO, - rule_id: uint32_t, - rule: std__vector[llama_grammar_element], - symbol_id_names: std__map[uint32_t, str], + rule_id: int, + rule: std.vector[llama_grammar_element], + symbol_id_names: std.map[int, str], ) -> None: # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { # throw std::runtime_error( # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); # } # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: + if ( + rule.empty() + or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value + ): raise RuntimeError( "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) @@ -1067,22 +1131,22 @@ def print_rule( # break; # } for i, elem in enumerate(rule[:-1]): - switch = elem.type # type: llama_gretype - if switch == llama_gretype.LLAMA_GRETYPE_END: + case = elem.type # type: int + if case == llama_gretype.LLAMA_GRETYPE_END.value: raise RuntimeError( "unexpected end of rule: " + str(rule_id) + "," + str(i) ) - elif switch == llama_gretype.LLAMA_GRETYPE_ALT: + elif case == llama_gretype.LLAMA_GRETYPE_ALT.value: print("| ", file=file, end="") - elif switch == llama_gretype.LLAMA_GRETYPE_RULE_REF: + elif case == llama_gretype.LLAMA_GRETYPE_RULE_REF.value: print(f"{symbol_id_names.at(elem.value)} ", file=file, end="") - elif switch == llama_gretype.LLAMA_GRETYPE_CHAR: + elif case == llama_gretype.LLAMA_GRETYPE_CHAR.value: print("[", file=file, end="") print_grammar_char(file, elem.value) - elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_NOT: + elif case == llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value: print("[^", file=file, end="") 
print_grammar_char(file, elem.value) - elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER: + elif case == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " @@ -1092,7 +1156,7 @@ def print_rule( ) print("-", file=file, end="") print_grammar_char(file, elem.value) - elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_ALT: + elif case == llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_ALT without preceding char: " @@ -1110,8 +1174,8 @@ def print_rule( # fprintf(file, "] "); if is_char_element(elem): if rule[i + 1].type in ( - llama_gretype.LLAMA_GRETYPE_CHAR_ALT, - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, ): pass else: @@ -1142,7 +1206,7 @@ def print_rule( # } def print_grammar(file: TextIO, state: parse_state) -> None: try: - symbol_id_names = std__map() # type: std__map[uint32_t, str] + symbol_id_names = std.map() # type: std.map[int, str] for kv in state.symbol_ids.items(): symbol_id_names[kv[1]] = kv[0] @@ -1155,117 +1219,25 @@ def print_grammar(file: TextIO, state: parse_state) -> None: ) -def convert_to_rules( - llama_grammar_elements: std__vector[std__vector[llama_grammar_element]], -) -> Array[llama_cpp.llama_grammar_element_p]: - """Make an Array object that is used for `llama_grammer_init`""" +# def convert_to_rules( +# llama_grammar_elements: std.vector[std.vector[llama_grammar_element]], +# ) -> Array[llama_grammar_element_p]: +# """Make an Array object that is used for `llama_grammer_init`""" - # Step 1: Convert to c_llama_grammar_element - llama_grammar_element_p_p = ( - [] - ) # type: List[List[llama_cpp.llama_grammar_element]] - for subvector in llama_grammar_elements: - llama_grammar_element_p_p.append([]) - for elem in subvector: - c_llama_grammar_element = llama_cpp.llama_grammar_element() - c_llama_grammar_element.type = c_int(elem.type.value) - c_llama_grammar_element.value = c_uint32(elem.value) - llama_grammar_element_p_p[-1].append(c_llama_grammar_element) +# # Step 1: Convert each list to llama_grammar_element array and get pointer +# element_arrays = [ +# (llama_grammar_element * len(subvector))(*subvector) +# for subvector in llama_grammar_elements +# ] # type: List[Array[llama_grammar_element]] - # Step 2: Convert each list to llama_grammar_element array and get pointer - element_arrays = [ - (llama_cpp.llama_grammar_element * len(sublist))(*sublist) - for sublist in llama_grammar_element_p_p - ] # type: List[Array[llama_cpp.llama_grammar_element]] +# # Step 2: Get pointer of each array +# element_array_pointers = [ +# cast(subarray, llama_grammar_element_p) for subarray in element_arrays +# ] # type: List[llama_grammar_element_p] - # Step 3: Get pointer of each array - element_array_pointers = [ - cast(sublist, llama_cpp.llama_grammar_element_p) - for sublist in element_arrays - ] # type: List[llama_cpp.llama_grammar_element_p] - - # Step 4: Make array of these pointers and get its pointer - return (llama_cpp.llama_grammar_element_p * len(element_array_pointers))( - *element_array_pointers - ) - - -def parse_grammar_init_args( - bnf: str, -) -> Tuple[Array[llama_cpp.llama_grammar_element_p], c_size_t, c_size_t]: - """Parse a GBNF string and return tuple of `grammar rules` and `root symbol id`""" - parsed_grammar = parse(const_char_p(bnf)) # type: parse_state - 
if parsed_grammar.rules.empty(): - raise Exception( - f"{parse_grammar_init_args.__name__}: error parsing grammar file: parsed_grammar.rules is empty" - ) - print(f"{parse_grammar_init_args.__name__} grammar:", file=sys.stderr) - print_grammar(sys.stdout, parsed_grammar) - print(file=sys.stderr) - grammar_rules = ( - parsed_grammar.c_rules() - ) # type: std__vector[std__vector[llama_grammar_element]] - return ( - convert_to_rules(grammar_rules), - c_size_t(grammar_rules.size()), - c_size_t(parsed_grammar.symbol_ids.at("root")), - ) - - -def parse_grammar_init_args_from_file( - bnf_path: Union[str, Path] -) -> Tuple[Array[llama_cpp.llama_grammar_element_p], c_size_t, c_size_t]: - """Parse a GBNF file and return tuple of `grammar rules` and `root symbol id`""" - try: - with open(bnf_path) as f: - params_grammer = f.read() - except Exception as err: - raise Exception( - f"{parse_grammar_init_args_from_file.__name__}: error reading grammar file: {err}" - ) - - if params_grammer: - return parse_grammar_init_args(params_grammer) - - raise Exception( - f"{parse_grammar_init_args_from_file.__name__}: error parsing grammar file: params_grammer is empty" - ) - - -# def get_grammar_p(bnf: str) -> llama_cpp.llama_grammar_p: -# """Parse a GBNF string and return pointer to `llama_grammar`""" - -# grammar_rules, root_symbol_id = parse_rules(bnf) - -# grammar_element_p_p = convert_to_double_ptr( -# grammar_rules -# ) # type: llama_cpp.llama_grammar_element_p_p - -# c_llama_grammar_p = llama_cpp.llama_grammar_init( -# grammar_element_p_p, -# c_size_t(grammar_rules.size()), -# c_size_t(root_symbol_id), -# ) # type: llama_cpp.llama_grammar_p -# return c_llama_grammar_p - - -# def get_grammar_p_from_file( -# bnf_path: Union[str, Path] -# ) -> llama_cpp.llama_grammar_p: -# """Parse a GBNF file and return pointer to `llama_grammar`""" -# try: -# with open(bnf_path) as f: -# params_grammer = f.read() -# except Exception as err: -# raise Exception( -# f"{get_grammar_p_from_file.__name__}: error reading grammar file: {err}" -# ) - -# if params_grammer: -# return get_grammar_p(params_grammer) - -# raise Exception( -# f"{get_grammar_p_from_file.__name__}: error parsing grammar file: params_grammer is empty" +# # Step 3: Make array of these pointers and get its pointer +# return (llama_grammar_element_p * len(element_array_pointers))( +# *element_array_pointers # ) @@ -1282,14 +1254,8 @@ if __name__ == "__main__": ) args = parser.parse_args() - rules, n_rules, start_rule_index = parse_grammar_init_args_from_file( - args.grammar - ) - llama_grammar_p = llama_cpp.llama_grammar_init( - rules, - n_rules, - start_rule_index, - ) # type: llama_cpp.llama_grammar_p + llama_grammar = LlamaGrammar.from_file(Path(args.grammar)) + llama_grammar_ptr = llama_grammar.init_grammar() # ----- USAGE: # llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p) From b07713cb9f26d97e0d5f740e722e60ad0d7e9ddf Mon Sep 17 00:00:00 2001 From: c0sogi Date: Mon, 7 Aug 2023 15:16:25 +0900 Subject: [PATCH 19/38] reset grammar for every generation --- llama_cpp/llama.py | 9 ++- llama_cpp/llama_grammar.py | 125 +++++++++++-------------------------- 2 files changed, 39 insertions(+), 95 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index ab99ee5..9328c5b 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -364,7 +364,7 @@ class Llama: ) if grammar is not None: self.grammar = LlamaGrammar.from_file( - grammar + grammar, verbose=verbose ) # type: Optional[LlamaGrammar] else: self.grammar = None @@ 
-723,7 +723,6 @@ class Llama: The generated tokens. """ assert self.ctx is not None - if reset and len(self._input_ids) > 0: longest_prefix = 0 for a, b in zip(self._input_ids, tokens[:-1]): @@ -741,6 +740,9 @@ class Llama: if reset: self.reset() + if self.grammar is not None: + self.grammar.reset() + while True: self.eval(tokens) token = self.sample( @@ -1534,9 +1536,6 @@ class Llama: if self.ctx is not None: llama_cpp.llama_free(self.ctx) self.ctx = None - if self.grammar is not None: - llama_cpp.llama_grammar_free(self.grammar.grammar) - self.grammar = None def __getstate__(self): return dict( diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 06b2b7f..5388676 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1,6 +1,5 @@ """C++ implementation of the llama grammar parser.""" # flake8: noqa -import argparse from pathlib import Path import sys from ctypes import * # type: ignore @@ -19,7 +18,7 @@ from typing import ( overload, ) -import llama_cpp +from . import llama_cpp # Type aliases llama_grammar_element = llama_cpp.llama_grammar_element @@ -41,11 +40,19 @@ class Sentinel: class LlamaGrammar: """Keeps reference counts of all the arguments, so that they are not garbage collected by Python.""" + + def __del__(self) -> None: + """Free the grammar pointer when the object is deleted.""" + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar) + self.grammar = None def __init__( self, parsed_grammar: "parse_state", ) -> None: + """Initialize the grammar pointer from the parsed state.""" + self.parsed_grammar = parsed_grammar grammar_rules = ( parsed_grammar.c_rules() ) # type: std.vector[std.vector[llama_grammar_element]] @@ -69,22 +76,25 @@ class LlamaGrammar: self.n_rules = c_size_t(grammar_rules.size()) self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root")) - self.grammar = self.init_grammar() + self._grammar = llama_cpp.llama_grammar_init( + self.rules, self.n_rules, self.start_rule_index + ) @classmethod - def from_string(cls, grammar: str) -> "LlamaGrammar": + def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": parsed_grammar = parse(const_char_p(grammar)) # type: parse_state if parsed_grammar.rules.empty(): raise ValueError( f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty" ) - print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) - print_grammar(sys.stdout, parsed_grammar) - print(file=sys.stderr) + if verbose: + print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) + print_grammar(sys.stdout, parsed_grammar) + print(file=sys.stderr) return cls(parsed_grammar) @classmethod - def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar": + def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar": try: with open(file) as f: grammar = f.read() @@ -94,14 +104,27 @@ class LlamaGrammar: ) if grammar: - return cls.from_string(grammar) + return cls.from_string(grammar, verbose=verbose) raise ValueError( f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" ) - def init_grammar(self) -> llama_grammar_p: - return llama_cpp.llama_grammar_init( + @property + def grammar(self) -> llama_grammar_p: + if self._grammar is None: + raise ValueError( + f"{self.__class__.__name__}.grammar: grammar is freed" + ) + return self._grammar + + @grammar.setter + def grammar(self, value: Optional[llama_grammar_p]) -> None: + self._grammar = value + + def reset(self) -> None: + 
llama_cpp.llama_grammar_free(self.grammar) + self.grammar = llama_cpp.llama_grammar_init( self.rules, self.n_rules, self.start_rule_index ) @@ -1216,82 +1239,4 @@ def print_grammar(file: TextIO, state: parse_state) -> None: print( f"{print_grammar.__name__}: error printing grammar: {err}", file=sys.stderr, - ) - - -# def convert_to_rules( -# llama_grammar_elements: std.vector[std.vector[llama_grammar_element]], -# ) -> Array[llama_grammar_element_p]: -# """Make an Array object that is used for `llama_grammer_init`""" - -# # Step 1: Convert each list to llama_grammar_element array and get pointer -# element_arrays = [ -# (llama_grammar_element * len(subvector))(*subvector) -# for subvector in llama_grammar_elements -# ] # type: List[Array[llama_grammar_element]] - -# # Step 2: Get pointer of each array -# element_array_pointers = [ -# cast(subarray, llama_grammar_element_p) for subarray in element_arrays -# ] # type: List[llama_grammar_element_p] - -# # Step 3: Make array of these pointers and get its pointer -# return (llama_grammar_element_p * len(element_array_pointers))( -# *element_array_pointers -# ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate C++ parser from GBNF grammar" - ) - parser.add_argument( - "-g", - "--grammar", - type=str, - default="./vendor/llama.cpp/grammars/json.gbnf", - help="path to GBNF grammar file", - ) - - args = parser.parse_args() - llama_grammar = LlamaGrammar.from_file(Path(args.grammar)) - llama_grammar_ptr = llama_grammar.init_grammar() - - # ----- USAGE: - # llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p) - # llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...) - - # ----- SAMPLE OUTPUT: - # main grammar: - # root ::= object - # object ::= [{] ws object_11 [}] ws - # value ::= object | array | string | number | value_6 ws - # array ::= [[] ws array_15 []] ws - # string ::= ["] string_18 ["] ws - # number ::= number_19 number_25 number_29 ws - # value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l] - # ws ::= ws_31 - # object_8 ::= string [:] ws value object_10 - # object_9 ::= [,] ws string [:] ws value - # object_10 ::= object_9 object_10 | - # object_11 ::= object_8 | - # array_12 ::= value array_14 - # array_13 ::= [,] ws value - # array_14 ::= array_13 array_14 | - # array_15 ::= array_12 | - # string_16 ::= [^"\] | [\] string_17 - # string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] - # string_18 ::= string_16 string_18 | - # number_19 ::= number_20 number_21 - # number_20 ::= [-] | - # number_21 ::= [0-9] | [1-9] number_22 - # number_22 ::= [0-9] number_22 | - # number_23 ::= [.] 
number_24 - # number_24 ::= [0-9] number_24 | [0-9] - # number_25 ::= number_23 | - # number_26 ::= [eE] number_27 number_28 - # number_27 ::= [-+] | - # number_28 ::= [0-9] number_28 | [0-9] - # number_29 ::= number_26 | - # ws_30 ::= [ ] ws - # ws_31 ::= ws_30 | + ) \ No newline at end of file From 0d7d2031a9401a483293d9e91749ee64a9f64d54 Mon Sep 17 00:00:00 2001 From: c0sogi Date: Mon, 7 Aug 2023 17:02:33 +0900 Subject: [PATCH 20/38] prevent memory access error by llama_grammar_free --- llama_cpp/llama_grammar.py | 253 ++++++++++++++----------------------- 1 file changed, 98 insertions(+), 155 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 5388676..f35f9fa 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -40,7 +40,7 @@ class Sentinel: class LlamaGrammar: """Keeps reference counts of all the arguments, so that they are not garbage collected by Python.""" - + def __del__(self) -> None: """Free the grammar pointer when the object is deleted.""" if self.grammar is not None: @@ -52,33 +52,12 @@ class LlamaGrammar: parsed_grammar: "parse_state", ) -> None: """Initialize the grammar pointer from the parsed state.""" - self.parsed_grammar = parsed_grammar - grammar_rules = ( + self._grammar_rules = ( parsed_grammar.c_rules() - ) # type: std.vector[std.vector[llama_grammar_element]] - - # Step 1: Convert each list to llama_grammar_element array and get pointer - self.element_arrays = [ - (llama_grammar_element * len(sublist))(*sublist) - for sublist in grammar_rules - ] # type: List[Array[llama_grammar_element]] - - # Step 2: Get pointer of each array - self.element_array_pointers = [ - cast(subarray, llama_grammar_element_p) - for subarray in self.element_arrays - ] # type: List[llama_grammar_element_p] - - # Step 3: Make array of these pointers and get its pointer - self.rules = ( - llama_grammar_element_p * len(self.element_array_pointers) - )(*self.element_array_pointers) - - self.n_rules = c_size_t(grammar_rules.size()) - self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root")) - self._grammar = llama_cpp.llama_grammar_init( - self.rules, self.n_rules, self.start_rule_index - ) + ) # type: std.vector[std.vector[LlamaGrammarElement]] + self._n_rules = self._grammar_rules.size() # type: int + self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int + self.grammar = self.init() @classmethod def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": @@ -110,23 +89,45 @@ class LlamaGrammar: f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" ) - @property - def grammar(self) -> llama_grammar_p: - if self._grammar is None: - raise ValueError( - f"{self.__class__.__name__}.grammar: grammar is freed" - ) - return self._grammar - - @grammar.setter - def grammar(self, value: Optional[llama_grammar_p]) -> None: - self._grammar = value + def init(self) -> None: + # Step 1: Convert LlamaGrammarElement to llama_grammar_element + self._element_lists = [ + [ + llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value)) + for elem in subvector + ] + for subvector in self._grammar_rules + ] # type: List[List[llama_grammar_element]] + + # Step 2: Convert each list to llama_grammar_element array and get pointer + self._element_arrays = [ + (llama_grammar_element * len(sublist))(*sublist) + for sublist in self._element_lists + ] # type: List[Array[llama_grammar_element]] + + # Step 3: Get pointer of each array + self._element_array_pointers = [ + cast(subarray, 
llama_grammar_element_p) for subarray in self._element_arrays + ] # type: List[llama_grammar_element_p] + + # Step 4: Make array of these pointers and get its pointer + self._rules = (llama_grammar_element_p * len(self._element_array_pointers))( + *self._element_array_pointers + ) + self.grammar = llama_cpp.llama_grammar_init( + self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index) + ) def reset(self) -> None: - llama_cpp.llama_grammar_free(self.grammar) - self.grammar = llama_cpp.llama_grammar_init( - self.rules, self.n_rules, self.start_rule_index - ) + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar) + self.init() + + +class LlamaGrammarElement: + def __init__(self, type: "llama_gretype", value: int): + self.type = type + self.value = value # Unicode code point or rule ID class const_char_p: @@ -182,21 +183,15 @@ class const_char_p: ) def __eq__(self: Ptr, other: Ptr) -> bool: - assert ( - self.value == other.value - ), "comparing pointers from different strings" + assert self.value == other.value, "comparing pointers from different strings" return self.pos == other.pos def __lt__(self: Ptr, other: Ptr) -> bool: - assert ( - self.value == other.value - ), "comparing pointers from different strings" + assert self.value == other.value, "comparing pointers from different strings" return self.pos < other.pos def __gt__(self: Ptr, other: Ptr) -> bool: - assert ( - self.value == other.value - ), "comparing pointers from different strings" + assert self.value == other.value, "comparing pointers from different strings" return self.pos > other.pos @@ -220,9 +215,7 @@ class std: def _check_version(self): if self._version != self._vector._version: - raise RuntimeError( - "Iterator used after vector was modified." - ) + raise RuntimeError("Iterator used after vector was modified.") def __iter__(self): return self @@ -280,16 +273,12 @@ class std: ) -> None: if new_size > self.size(): if fill_value_factory is None: - raise ValueError( - "A fill value factory function must be provided." 
- ) + raise ValueError("A fill value factory function must be provided.") self.reserve(new_size, fill_value_factory) elif new_size < self.size(): self[:] = self[:new_size] - def reserve( - self, capacity: int, fill_value_factory: Callable[[], T] - ) -> None: + def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None: if capacity > self.size(): fill_value = fill_value_factory() self.extend([fill_value] * (capacity - self.size())) @@ -401,9 +390,7 @@ class std: for k in keys: if k >= key: return self.iterator(self, k) - raise ValueError( - "No key found that is not less than the input key" - ) + raise ValueError("No key found that is not less than the input key") except TypeError: raise TypeError("Keys of type T cannot be sorted.") @@ -460,9 +447,7 @@ class llama_gretype(Enum): class parse_state: def __init__(self): self.symbol_ids: std.map[str, int] = std.map() - self.rules: std.vector[ - std.vector[llama_grammar_element] - ] = std.vector() + self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector() # std::vector parse_state::c_rules() { # std::vector ret; @@ -471,16 +456,16 @@ class parse_state: # } # return ret; # } - def c_rules(self) -> std.vector[std.vector[llama_grammar_element]]: - ret = ( - std.vector() - ) # type: std.vector[std.vector[llama_grammar_element]] + def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]: + ret = std.vector() # type: std.vector[std.vector[LlamaGrammarElement]] for rule in self.rules: ret.push_back(rule.data()) return ret def __repr__(self) -> str: - return f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" + return ( + f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" + ) # struct llama_grammar { @@ -531,12 +516,12 @@ def generate_symbol_id(state: parse_state, base_name: str) -> int: def add_rule( state: parse_state, rule_id: int, - rule: std.vector[llama_grammar_element], + rule: std.vector[LlamaGrammarElement], ) -> None: if state.rules.size() <= rule_id: state.rules.resize( rule_id + 1, - fill_value_factory=std.vector[llama_grammar_element], + fill_value_factory=std.vector[LlamaGrammarElement], ) state.rules[rule_id] = rule @@ -575,9 +560,7 @@ def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]: # return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); # } def is_word_char(c: str) -> bool: - return ( - ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") - ) + return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") # std::pair parse_hex(const char * src, int size) { @@ -619,9 +602,7 @@ def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]: break pos += 1 if pos != end: - raise RuntimeError( - "expecting " + str(size) + " hex chars at " + str(src) - ) + raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src)) return (value, pos) @@ -707,9 +688,7 @@ def parse_name(src: const_char_p) -> const_char_p: # } def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p: pos = const_char_p(src) # type: const_char_p - while pos[0] in (" ", "\t", "#") or ( - newline_ok and pos[0] in ("\r", "\n") - ): + while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")): if pos[0] == "#": while pos[0] is not None and pos[0] not in ("\r", "\n"): pos += 1 @@ -728,7 +707,7 @@ def parse_sequence( state: parse_state, src: const_char_p, rule_name: str, - out_elements: std.vector[llama_grammar_element], + out_elements: 
std.vector[LlamaGrammarElement], is_nested: bool, ) -> const_char_p: # size_t last_sym_start = out_elements.size(); @@ -753,9 +732,7 @@ def parse_sequence( char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_CHAR.value, char_pair[0] - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0]) ) pos = parse_space(pos + 1, is_nested) # } else if (*pos == '[') { // char range(s) @@ -763,9 +740,7 @@ def parse_sequence( # enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; elif pos[0] == "[": # char range(s) pos += 1 - start_type = ( - llama_gretype.LLAMA_GRETYPE_CHAR - ) # type: llama_gretype + start_type = llama_gretype.LLAMA_GRETYPE_CHAR # type: llama_gretype # if (*pos == '^') { # pos++; # start_type = LLAMA_GRETYPE_CHAR_NOT; @@ -790,9 +765,7 @@ def parse_sequence( if last_sym_start < out_elements.size() else start_type ) # type: llama_gretype - out_elements.push_back( - llama_grammar_element(type.value, char_pair[0]) - ) + out_elements.push_back(LlamaGrammarElement(type, char_pair[0])) # if (pos[0] == '-' && pos[1] != ']') { # auto endchar_pair = parse_char(pos + 1); # pos = endchar_pair.second; @@ -800,13 +773,11 @@ def parse_sequence( # } # } if pos[0] == "-" and pos[1] != "]": - endchar_pair = parse_char( - pos + 1 - ) # type: Tuple[int, const_char_p] + endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p] pos = endchar_pair[1] out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, + LlamaGrammarElement( + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair[0], ) ) @@ -820,15 +791,11 @@ def parse_sequence( # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); elif is_word_char(pos[0]): # rule reference name_end = parse_name(pos) # type: const_char_p - ref_rule_id = get_symbol_id( - state, pos, name_end - pos - ) # type: int + ref_rule_id = get_symbol_id(state, pos, name_end - pos) # type: int pos = parse_space(name_end, is_nested) last_sym_start = out_elements.size() out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, ref_rule_id - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id) ) # } else if (*pos == '(') { // grouping # // parse nested alternates into synthesized rule @@ -850,9 +817,7 @@ def parse_sequence( last_sym_start = out_elements.size() # output reference to synthesized rule out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) ) if pos[0] != ")": raise RuntimeError("expecting ')' at " + str(pos)) @@ -863,9 +828,7 @@ def parse_sequence( # } elif pos[0] in ("*", "+", "?"): # repetition operator if last_sym_start == out_elements.size(): - raise RuntimeError( - "expecting preceding item to */+/? at " + str(pos) - ) + raise RuntimeError("expecting preceding item to */+/? 
at " + str(pos)) # // apply transformation to previous symbol (last_sym_start to end) according to # // rewrite rules: # // S* --> S' ::= S S' | @@ -878,8 +841,8 @@ def parse_sequence( # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); sub_rule_id = generate_symbol_id(state, rule_name) # type: int sub_rule = std.vector[ - llama_grammar_element - ]() # type: std.vector[llama_grammar_element] + LlamaGrammarElement + ]() # type: std.vector[LlamaGrammarElement] sub_rule.insert( sub_rule.end(), out_elements.begin() + last_sym_start, @@ -893,13 +856,11 @@ def parse_sequence( # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); if pos[0] in ("*", "+"): sub_rule.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id + LlamaGrammarElement( + llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id ) ) - sub_rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0) - ) + sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) # if (*pos == '+') { # // add preceding symbol as alternate only for '+' (otherwise empty) # sub_rule.insert( @@ -918,16 +879,12 @@ def parse_sequence( out_elements.begin() + last_sym_start, out_elements.end(), ) - sub_rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0) - ) + sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) add_rule(state, sub_rule_id, sub_rule) # in original rule, replace previous symbol with reference to generated rule out_elements.resize(last_sym_start) out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) ) pos = parse_space(pos + 1, is_nested) # } else { @@ -965,19 +922,13 @@ def parse_alternates( rule_id: int, is_nested: bool, ) -> const_char_p: - rule = std.vector() # type: std.vector[llama_grammar_element] - pos = parse_sequence( - state, src, rule_name, rule, is_nested - ) # type: const_char_p + rule = std.vector() # type: std.vector[LlamaGrammarElement] + pos = parse_sequence(state, src, rule_name, rule, is_nested) # type: const_char_p while pos[0] == "|": - rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0) - ) + rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) pos = parse_space(pos + 1, True) pos = parse_sequence(state, pos, rule_name, rule, is_nested) - rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0) - ) + rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) add_rule(state, rule_id, rule) return pos @@ -1017,9 +968,7 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: raise RuntimeError("expecting ::= at " + str(pos)) pos = parse_space(pos + 3, True) # type: const_char_p - pos = parse_alternates( - state, pos, name, rule_id, False - ) # type: const_char_p + pos = parse_alternates(state, pos, name, rule_id, False) # type: const_char_p if pos[0] == "\r": pos += 2 if pos[1] == "\n" else 1 @@ -1080,7 +1029,7 @@ def print_grammar_char(file: TextIO, c: int) -> None: # default: return false; # } # } -def is_char_element(elem: llama_grammar_element) -> bool: +def is_char_element(elem: LlamaGrammarElement) -> bool: return elem.type in ( llama_gretype.LLAMA_GRETYPE_CHAR.value, llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value, @@ -1097,7 +1046,7 @@ def is_char_element(elem: llama_grammar_element) -> bool: def print_rule( file: TextIO, rule_id: int, - rule: 
std.vector[llama_grammar_element], + rule: std.vector[LlamaGrammarElement], symbol_id_names: std.map[int, str], ) -> None: # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { @@ -1105,13 +1054,9 @@ def print_rule( # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); # } # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - if ( - rule.empty() - or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value - ): + if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value: raise RuntimeError( - "malformed rule, does not end with LLAMA_GRETYPE_END: " - + str(rule_id) + "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) ) print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { @@ -1154,22 +1099,20 @@ def print_rule( # break; # } for i, elem in enumerate(rule[:-1]): - case = elem.type # type: int - if case == llama_gretype.LLAMA_GRETYPE_END.value: - raise RuntimeError( - "unexpected end of rule: " + str(rule_id) + "," + str(i) - ) - elif case == llama_gretype.LLAMA_GRETYPE_ALT.value: + case = elem.type # type: llama_gretype + if case is llama_gretype.LLAMA_GRETYPE_END.value: + raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) + elif case is llama_gretype.LLAMA_GRETYPE_ALT: print("| ", file=file, end="") - elif case == llama_gretype.LLAMA_GRETYPE_RULE_REF.value: + elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: print(f"{symbol_id_names.at(elem.value)} ", file=file, end="") - elif case == llama_gretype.LLAMA_GRETYPE_CHAR.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR: print("[", file=file, end="") print_grammar_char(file, elem.value) - elif case == llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT: print("[^", file=file, end="") print_grammar_char(file, elem.value) - elif case == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " @@ -1179,7 +1122,7 @@ def print_rule( ) print("-", file=file, end="") print_grammar_char(file, elem.value) - elif case == llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_ALT without preceding char: " @@ -1239,4 +1182,4 @@ def print_grammar(file: TextIO, state: parse_state) -> None: print( f"{print_grammar.__name__}: error printing grammar: {err}", file=sys.stderr, - ) \ No newline at end of file + ) From 4cf2fc7d3d2635190f670eff41f0d1e52462f59c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Aug 2023 20:09:55 +0000 Subject: [PATCH 21/38] Bump mkdocs from 1.5.1 to 1.5.2 Bumps [mkdocs](https://github.com/mkdocs/mkdocs) from 1.5.1 to 1.5.2. - [Release notes](https://github.com/mkdocs/mkdocs/releases) - [Commits](https://github.com/mkdocs/mkdocs/compare/1.5.1...1.5.2) --- updated-dependencies: - dependency-name: mkdocs dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1dcbfe6..932f15f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -744,13 +744,13 @@ files = [ [[package]] name = "mkdocs" -version = "1.5.1" +version = "1.5.2" description = "Project documentation with Markdown." optional = false python-versions = ">=3.7" files = [ - {file = "mkdocs-1.5.1-py3-none-any.whl", hash = "sha256:67e889f8d8ba1fe5decdfc59f5f8f21d6a8925a129339e93dede303bdea03a98"}, - {file = "mkdocs-1.5.1.tar.gz", hash = "sha256:f2f323c62fffdf1b71b84849e39aef56d6852b3f0a5571552bca32cefc650209"}, + {file = "mkdocs-1.5.2-py3-none-any.whl", hash = "sha256:60a62538519c2e96fe8426654a67ee177350451616118a41596ae7c876bb7eac"}, + {file = "mkdocs-1.5.2.tar.gz", hash = "sha256:70d0da09c26cff288852471be03c23f0f521fc15cf16ac89c7a3bfb9ae8d24f9"}, ] [package.dependencies] @@ -1757,4 +1757,4 @@ server = ["fastapi", "pydantic-settings", "sse-starlette", "uvicorn"] [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "6718d680fa89f9518a232c1110ba43958d3e21c54c4dbd9129effa4f40a02b81" +content-hash = "4bfb67dfb72b02c845376211f7f958b2ece8c985944fbd03d246c858e846ddf6" diff --git a/pyproject.toml b/pyproject.toml index e3fcd0e..c636d5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ pydantic-settings = { version = ">=2.0.1", optional = true } [tool.poetry.group.dev.dependencies] black = "^23.7.0" twine = "^4.0.2" -mkdocs = "^1.4.3" +mkdocs = "^1.5.2" mkdocstrings = {extras = ["python"], version = "^0.22.0"} mkdocs-material = "^9.1.21" pytest = "^7.4.0" From 83f8438c4fc6a3b561c0a6881fa5f46c74d993bf Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Aug 2023 20:10:12 +0000 Subject: [PATCH 22/38] Bump fastapi from 0.100.1 to 0.101.0 Bumps [fastapi](https://github.com/tiangolo/fastapi) from 0.100.1 to 0.101.0. - [Release notes](https://github.com/tiangolo/fastapi/releases) - [Commits](https://github.com/tiangolo/fastapi/compare/0.100.1...0.101.0) --- updated-dependencies: - dependency-name: fastapi dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] --- poetry.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1dcbfe6..667d88d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -384,17 +384,17 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.100.1" +version = "0.101.0" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = true python-versions = ">=3.7" files = [ - {file = "fastapi-0.100.1-py3-none-any.whl", hash = "sha256:ec6dd52bfc4eff3063cfcd0713b43c87640fefb2687bbbe3d8a08d94049cdf32"}, - {file = "fastapi-0.100.1.tar.gz", hash = "sha256:522700d7a469e4a973d92321ab93312448fbe20fca9c8da97effc7e7bc56df23"}, + {file = "fastapi-0.101.0-py3-none-any.whl", hash = "sha256:494eb3494d89e8079c20859d7ca695f66eaccc40f46fe8c75ab6186d15f05ffd"}, + {file = "fastapi-0.101.0.tar.gz", hash = "sha256:ca2ae65fe42f6a34b5cf6c994337149154b1b400c39809d7b2dccdceb5ae77af"}, ] [package.dependencies] -pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<3.0.0" +pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" starlette = ">=0.27.0,<0.28.0" typing-extensions = ">=4.5.0" From f6a7850e1a316c5168ba51cbdbb669d774cd0c15 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 8 Aug 2023 14:30:58 -0400 Subject: [PATCH 23/38] Update llama.cpp --- llama_cpp/llama_cpp.py | 2 ++ vendor/llama.cpp | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 423a4a0..bbb2a1e 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -181,6 +181,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) # // Keep the booleans together to avoid misalignment during copy-by-value. # bool low_vram; // if true, reduce VRAM usage at the cost of performance +# bool mul_mat_q; // if true, use experimental mul_mat_q kernels # bool f16_kv; // use fp16 for KV cache # bool logits_all; // the llama_eval() call computes all logits, not just the last one # bool vocab_only; // only load the vocabulary, no weights @@ -203,6 +204,7 @@ class llama_context_params(Structure): ("progress_callback", llama_progress_callback), ("progress_callback_user_data", c_void_p), ("low_vram", c_bool), + ("mul_mat_q", c_bool), ("f16_kv", c_bool), ("logits_all", c_bool), ("vocab_only", c_bool), diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 41c6741..f5bfea0 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 41c674161fb2459bdf7806d1eebead15bc5d046e +Subproject commit f5bfea0580e417f99850d5456ca541d871a3e48c From d015bdb4f8ab5591a9147443ec3e0d4f1d0a3192 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 8 Aug 2023 14:35:06 -0400 Subject: [PATCH 24/38] Add mul_mat_q option --- llama_cpp/llama.py | 4 ++++ llama_cpp/server/app.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 47f71e9..9a8c090 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -227,6 +227,7 @@ class Llama: rope_freq_scale: float = 1.0, n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b rms_norm_eps: Optional[float] = None, # (TEMPORARY) + mul_mat_q: Optional(bool) = None, # (TEMPORARY) verbose: bool = True, ): """Load a llama.cpp model from `model_path`. 
@@ -293,6 +294,9 @@ class Llama: if rms_norm_eps is not None: self.params.rms_norm_eps = rms_norm_eps + if mul_mat_q is not None: + self.params.mul_mat_q = mul_mat_q + self.last_n_tokens_size = last_n_tokens_size self.n_batch = min(n_ctx, n_batch) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 4afcfd5..3d5238b 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -103,6 +103,10 @@ class Settings(BaseSettings): default=None, description="TEMPORARY", ) + mul_mat_q: Optional[bool] = Field( + default=None, + description="TEMPORARY", + ) class ErrorResponse(TypedDict): From 1e844d323824fd5dc62906b005daeee7c6707efd Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 8 Aug 2023 15:07:28 -0400 Subject: [PATCH 25/38] fix --- llama_cpp/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 9244b8b..56143c9 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -230,7 +230,7 @@ class Llama: grammar: Optional[Union[str, Path]] = None, n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b rms_norm_eps: Optional[float] = None, # (TEMPORARY) - mul_mat_q: Optional(bool) = None, # (TEMPORARY) + mul_mat_q: Optional[bool] = None, # (TEMPORARY) verbose: bool = True, ): """Load a llama.cpp model from `model_path`. From 66fb0345e8ade77d8eb3ae8103d1954e9a88d7ef Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 8 Aug 2023 15:08:54 -0400 Subject: [PATCH 26/38] Move grammar to function call argument --- llama_cpp/llama.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 56143c9..a996d5c 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -227,7 +227,6 @@ class Llama: tensor_split: Optional[List[float]] = None, rope_freq_base: float = 10000.0, rope_freq_scale: float = 1.0, - grammar: Optional[Union[str, Path]] = None, n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b rms_norm_eps: Optional[float] = None, # (TEMPORARY) mul_mat_q: Optional[bool] = None, # (TEMPORARY) @@ -254,7 +253,6 @@ class Llama: tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split. rope_freq_base: Base frequency for rope sampling. rope_freq_scale: Scale factor for rope sampling. - grammar: Path to a BNF grammar file to use for grammar based sampling. verbose: Print verbose output to stderr. 
Raises: @@ -383,12 +381,6 @@ class Llama: self.scores: npt.NDArray[np.single] = np.ndarray( (n_ctx, self._n_vocab), dtype=np.single ) - if grammar is not None: - self.grammar = LlamaGrammar.from_file( - grammar, verbose=verbose - ) # type: Optional[LlamaGrammar] - else: - self.grammar = None @property def _input_ids(self) -> npt.NDArray[np.intc]: @@ -527,6 +519,7 @@ class Llama: mirostat_eta: llama_cpp.c_float, penalize_nl: bool = True, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ): assert self.ctx is not None assert self.n_tokens > 0 @@ -574,11 +567,11 @@ class Llama: if not penalize_nl: candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit) - if self.grammar is not None: + if grammar is not None: llama_cpp.llama_sample_grammar( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore - grammar=self.grammar.grammar, + grammar=grammar.grammar, ) if temp.value == 0.0: @@ -650,10 +643,10 @@ class Llama: ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore ) - if self.grammar is not None: + if grammar is not None: llama_cpp.llama_grammar_accept_token( ctx=self.ctx, - grammar=self.grammar.grammar, + grammar=grammar.grammar, token=llama_cpp.ctypes.c_int(id), ) return id @@ -672,6 +665,7 @@ class Llama: mirostat_tau: float = 5.0, penalize_nl: bool = True, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ): """Sample a token from the model. @@ -705,6 +699,7 @@ class Llama: mirostat_eta=llama_cpp.c_float(mirostat_eta), penalize_nl=penalize_nl, logits_processor=logits_processor, + grammar=grammar, ) def generate( @@ -723,6 +718,7 @@ class Llama: mirostat_eta: float = 0.1, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, + grammar: Optional[LlamaGrammar] = None, ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. @@ -761,8 +757,8 @@ class Llama: if reset: self.reset() - if self.grammar is not None: - self.grammar.reset() + if grammar is not None: + grammar.reset() while True: self.eval(tokens) @@ -778,6 +774,7 @@ class Llama: mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, logits_processor=logits_processor, + grammar=grammar, ) if stopping_criteria is not None and stopping_criteria( self._input_ids.tolist(), self._scores[-1, :].tolist() @@ -880,6 +877,7 @@ class Llama: model: Optional[str] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None @@ -957,6 +955,7 @@ class Llama: repeat_penalty=repeat_penalty, stopping_criteria=stopping_criteria, logits_processor=logits_processor, + grammar=grammar, ): if token == self._token_eos: text = self.detokenize(completion_tokens) @@ -1301,6 +1300,7 @@ class Llama: model: Optional[str] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. 
@@ -1345,6 +1345,7 @@ class Llama: model=model, stopping_criteria=stopping_criteria, logits_processor=logits_processor, + grammar=grammar ) if stream: chunks: Iterator[CompletionChunk] = completion_or_chunks @@ -1374,6 +1375,7 @@ class Llama: model: Optional[str] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. @@ -1418,6 +1420,7 @@ class Llama: model=model, stopping_criteria=stopping_criteria, logits_processor=logits_processor, + grammar=grammar, ) def _convert_text_completion_to_chat( @@ -1498,6 +1501,7 @@ class Llama: mirostat_eta: float = 0.1, model: Optional[str] = None, logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: """Generate a chat completion from a list of messages. @@ -1540,6 +1544,7 @@ class Llama: mirostat_eta=mirostat_eta, model=model, logits_processor=logits_processor, + grammar=grammar, ) if stream: chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore From 17dd7fa8e02b48c30c5a80f0225a9b93a6f5d4b8 Mon Sep 17 00:00:00 2001 From: Hannes Krumbiegel Date: Fri, 11 Aug 2023 09:58:48 +0200 Subject: [PATCH 27/38] Add py.typed --- llama_cpp/py.typed | 0 setup.py | 1 + 2 files changed, 1 insertion(+) create mode 100644 llama_cpp/py.typed diff --git a/llama_cpp/py.typed b/llama_cpp/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py index 8e6139d..74040d5 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ setup( author_email="abetlen@gmail.com", license="MIT", package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"}, + package_data={"llama_cpp": ["py.typed"]}, packages=["llama_cpp", "llama_cpp.server"], install_requires=["typing-extensions>=4.5.0", "numpy>=1.20.0", "diskcache>=5.6.1"], extras_require={ From d018c7b01dd08518b59ebcedc603111243a39391 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Sat, 12 Aug 2023 18:41:47 +0800 Subject: [PATCH 28/38] Add doc string for n_gpu_layers argument --- llama_cpp/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index a996d5c..20a5e0c 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -239,6 +239,7 @@ class Llama: n_ctx: Maximum context size. n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined. seed: Random seed. -1 for random. + n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded. f16_kv: Use half-precision for key/value cache. logits_all: Return logits for all tokens, not just the last token. vocab_only: Only load the vocabulary no weights. 
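Patch 28 only documents the `n_gpu_layers` argument, and the commit that follows makes `-1` offload every layer. For orientation, here is a minimal usage sketch of that argument; the model path and layer count below are placeholders, not taken from these patches:

    from llama_cpp import Llama

    # Hypothetical local model file; n_gpu_layers maps to llama.cpp's -ngl option.
    llm = Llama(
        model_path="./models/llama-2-7b.ggmlv3.q4_0.bin",
        n_gpu_layers=32,  # offload 32 transformer layers to VRAM; -1 offloads all layers once patch 29 lands
        n_ctx=2048,       # context window
    )
    print(llm("Q: Name the largest planet. A:", max_tokens=16)["choices"][0]["text"])
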
From c471871d0bfe0ac8d9bf69f6cee6cff768776ad2 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Sun, 13 Aug 2023 11:21:28 +0800 Subject: [PATCH 29/38] make n_gpu_layers=-1 offload all layers --- llama_cpp/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 20a5e0c..8115d46 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -268,7 +268,7 @@ class Llama: self.params = llama_cpp.llama_context_default_params() self.params.n_ctx = n_ctx - self.params.n_gpu_layers = n_gpu_layers + self.params.n_gpu_layers = 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers # 0x7FFFFFFF is INT32 max, will be auto set to all layers self.params.seed = seed self.params.f16_kv = f16_kv self.params.logits_all = logits_all From 077f8ed23e8e282d94e12de32bf699f008a4b9fe Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 14 Aug 2023 20:36:56 +0000 Subject: [PATCH 30/38] Bump fastapi from 0.101.0 to 0.101.1 Bumps [fastapi](https://github.com/tiangolo/fastapi) from 0.101.0 to 0.101.1. - [Release notes](https://github.com/tiangolo/fastapi/releases) - [Commits](https://github.com/tiangolo/fastapi/compare/0.101.0...0.101.1) --- updated-dependencies: - dependency-name: fastapi dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- poetry.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0157ab6..ef2f0d7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -384,13 +384,13 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.101.0" +version = "0.101.1" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = true python-versions = ">=3.7" files = [ - {file = "fastapi-0.101.0-py3-none-any.whl", hash = "sha256:494eb3494d89e8079c20859d7ca695f66eaccc40f46fe8c75ab6186d15f05ffd"}, - {file = "fastapi-0.101.0.tar.gz", hash = "sha256:ca2ae65fe42f6a34b5cf6c994337149154b1b400c39809d7b2dccdceb5ae77af"}, + {file = "fastapi-0.101.1-py3-none-any.whl", hash = "sha256:aef5f8676eb1b8389952e1fe734abe20f04b71f6936afcc53b320ba79b686a4b"}, + {file = "fastapi-0.101.1.tar.gz", hash = "sha256:7b32000d14ca9992f7461117b81e4ef9ff0c07936af641b4fe40e67d5f9d63cb"}, ] [package.dependencies] From e91969c88899bb81e03710fa06019c472b7926c3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 14 Aug 2023 20:38:42 +0000 Subject: [PATCH 31/38] Bump sse-starlette from 1.6.1 to 1.6.5 Bumps [sse-starlette](https://github.com/sysid/sse-starlette) from 1.6.1 to 1.6.5. - [Release notes](https://github.com/sysid/sse-starlette/releases) - [Changelog](https://github.com/sysid/sse-starlette/blob/main/CHANGELOG.md) - [Commits](https://github.com/sysid/sse-starlette/compare/v1.6.1...v1.6.5) --- updated-dependencies: - dependency-name: sse-starlette dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] --- poetry.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0157ab6..eb3a47f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1562,13 +1562,13 @@ files = [ [[package]] name = "sse-starlette" -version = "1.6.1" +version = "1.6.5" description = "\"SSE plugin for Starlette\"" optional = true python-versions = ">=3.8" files = [ - {file = "sse-starlette-1.6.1.tar.gz", hash = "sha256:6208af2bd7d0887c92f1379da14bd1f4db56bd1274cc5d36670c683d2aa1de6a"}, - {file = "sse_starlette-1.6.1-py3-none-any.whl", hash = "sha256:d8f18f1c633e355afe61cc5e9c92eea85badcb8b2d56ec8cfb0a006994aa55da"}, + {file = "sse-starlette-1.6.5.tar.gz", hash = "sha256:819f2c421fb37067380fe3dcaba246c476b02651b7bb7601099a378ad802a0ac"}, + {file = "sse_starlette-1.6.5-py3-none-any.whl", hash = "sha256:68b6b7eb49be0c72a2af80a055994c13afcaa4761b29226beb208f954c25a642"}, ] [package.dependencies] From 485ad97909383a10c0d23d6cf6cff7a9bacd07ec Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 14 Aug 2023 20:40:16 +0000 Subject: [PATCH 32/38] Bump pydantic-settings from 2.0.2 to 2.0.3 Bumps [pydantic-settings](https://github.com/pydantic/pydantic-settings) from 2.0.2 to 2.0.3. - [Release notes](https://github.com/pydantic/pydantic-settings/releases) - [Commits](https://github.com/pydantic/pydantic-settings/compare/v2.0.2...v2.0.3) --- updated-dependencies: - dependency-name: pydantic-settings dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- poetry.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0157ab6..872014a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1133,13 +1133,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pydantic-settings" -version = "2.0.2" +version = "2.0.3" description = "Settings management using Pydantic" optional = true python-versions = ">=3.7" files = [ - {file = "pydantic_settings-2.0.2-py3-none-any.whl", hash = "sha256:6183a2abeab465d5a3ab69758e9a22d38b0cc2ba193f0b85f6971a252ea630f6"}, - {file = "pydantic_settings-2.0.2.tar.gz", hash = "sha256:342337fff50b23585e807a86dec85037900972364435c55c2fc00d16ff080539"}, + {file = "pydantic_settings-2.0.3-py3-none-any.whl", hash = "sha256:ddd907b066622bd67603b75e2ff791875540dc485b7307c4fffc015719da8625"}, + {file = "pydantic_settings-2.0.3.tar.gz", hash = "sha256:962dc3672495aad6ae96a4390fac7e593591e144625e5112d359f8f67fb75945"}, ] [package.dependencies] From b345d6098766d40ad9f83d20c9056fcab9fd1ae0 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 14 Aug 2023 22:33:30 -0400 Subject: [PATCH 33/38] Update llama.cpp --- llama_cpp/llama_cpp.py | 32 ++++++++++++++++++++++++++++++++ vendor/llama.cpp | 2 +- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 27348b0..0fd3209 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -158,6 +158,25 @@ llama_token_data_array_p = POINTER(llama_token_data_array) llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) +# enum llama_log_level { +# LLAMA_LOG_LEVEL_ERROR = 2, +# LLAMA_LOG_LEVEL_WARN = 3, +# LLAMA_LOG_LEVEL_INFO = 4 +# }; +LLAMA_LOG_LEVEL_ERROR = c_int(2) +LLAMA_LOG_LEVEL_WARN = c_int(3) +LLAMA_LOG_LEVEL_INFO = c_int(4) + + +# // Signature for logging events +# // Note that text includes the new line character at the end for most events. 
+# // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it +# // if it exists. +# // It might not exist for progress report where '.' is output repeatedly. +# typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); +llama_log_callback = ctypes.CFUNCTYPE(None, c_int, c_char_p, c_void_p) + + # struct llama_context_params { # uint32_t seed; // RNG seed, -1 for random # int32_t n_ctx; // text context @@ -351,6 +370,19 @@ class llama_timings(Structure): ] +# // Set callback for all future logging events. +# // If this is not called, or NULL is supplied, everything is output on stderr. +# LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data); +def llama_log_set( + log_callback: "ctypes._FuncPointer", user_data: c_void_p # type: ignore +): + return _lib.llama_log_set(log_callback, user_data) + + +_lib.llama_log_set.argtypes = [llama_log_callback, c_void_p] +_lib.llama_log_set.restype = None + + # LLAMA_API int llama_max_devices(); def llama_max_devices() -> int: return _lib.llama_max_devices() diff --git a/vendor/llama.cpp b/vendor/llama.cpp index f5bfea0..3ebb009 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit f5bfea0580e417f99850d5456ca541d871a3e48c +Subproject commit 3ebb00935f3f0522b75df49c2769ab1774b91380 From 5788f1f2b2c219a4644137a5ac90d89f08f15ed7 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 14 Aug 2023 22:41:37 -0400 Subject: [PATCH 34/38] Remove unnused import --- llama_cpp/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index a996d5c..2a6f7cf 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1,5 +1,4 @@ import os -from pathlib import Path import sys import uuid import time From a240aa6b25be46aca06a9f3d0fc5b28528385638 Mon Sep 17 00:00:00 2001 From: c0sogi Date: Thu, 17 Aug 2023 21:00:44 +0900 Subject: [PATCH 35/38] Fix typos in llama_grammar --- llama_cpp/llama_grammar.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index f35f9fa..8ff1565 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1031,10 +1031,10 @@ def print_grammar_char(file: TextIO, c: int) -> None: # } def is_char_element(elem: LlamaGrammarElement) -> bool: return elem.type in ( - llama_gretype.LLAMA_GRETYPE_CHAR.value, - llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value, - llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value, - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, + llama_gretype.LLAMA_GRETYPE_CHAR, + llama_gretype.LLAMA_GRETYPE_CHAR_NOT, + llama_gretype.LLAMA_GRETYPE_CHAR_ALT, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, ) @@ -1054,9 +1054,10 @@ def print_rule( # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); # } # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value: + if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: raise RuntimeError( - "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) + "malformed rule, does not end with LLAMA_GRETYPE_END: " + + str(rule_id) ) print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { @@ -1100,8 +1101,10 @@ def print_rule( # } for i, elem in enumerate(rule[:-1]): case = elem.type # type: llama_gretype - if case 
is llama_gretype.LLAMA_GRETYPE_END.value: - raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) + if case is llama_gretype.LLAMA_GRETYPE_END: + raise RuntimeError( + "unexpected end of rule: " + str(rule_id) + "," + str(i) + ) elif case is llama_gretype.LLAMA_GRETYPE_ALT: print("| ", file=file, end="") elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: @@ -1140,8 +1143,8 @@ def print_rule( # fprintf(file, "] "); if is_char_element(elem): if rule[i + 1].type in ( - llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value, - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, + llama_gretype.LLAMA_GRETYPE_CHAR_ALT, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, ): pass else: From da1ef72c51059b30a5719331b021972f0aa6cdfc Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 17 Aug 2023 23:02:20 -0400 Subject: [PATCH 36/38] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 3ebb009..604b8bd 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 3ebb00935f3f0522b75df49c2769ab1774b91380 +Subproject commit 604b8bdfa6320bbcb018eebcc1252dfede603c6b From 8fc3fa9f1cfa0d0d2e11e152ca7dd9d50cceef64 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 17 Aug 2023 23:17:56 -0400 Subject: [PATCH 37/38] Bump version --- CHANGELOG.md | 7 +++++++ pyproject.toml | 2 +- setup.py | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ca220e..df635fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.78] + +### Added + +- Grammar based sampling via LlamaGrammar which can be passed to completions +- Make n_gpu_layers == -1 offload all layers + ## [0.1.77] - (llama.cpp) Update llama.cpp add support for LLaMa 2 70B diff --git a/pyproject.toml b/pyproject.toml index c636d5d..8735b60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "llama_cpp_python" -version = "0.1.77" +version = "0.1.78" description = "Python bindings for the llama.cpp library" authors = ["Andrei Betlen "] license = "MIT" diff --git a/setup.py b/setup.py index 74040d5..bdc5a2e 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( description="A Python wrapper for llama.cpp", long_description=long_description, long_description_content_type="text/markdown", - version="0.1.77", + version="0.1.78", author="Andrei Betlen", author_email="abetlen@gmail.com", license="MIT", From bbbf0f4fc47bd5f9880b799c82ad7c06f5003cb7 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 24 Aug 2023 00:17:00 -0400 Subject: [PATCH 38/38] Update llama.cpp --- llama_cpp/llama.py | 30 +- llama_cpp/llama_cpp.py | 718 +++++++++++++++++++++++------------------ vendor/llama.cpp | 2 +- 3 files changed, 415 insertions(+), 335 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 21c0875..bfcae18 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -371,8 +371,8 @@ class Llama: sorted=sorted, ) self._candidates = candidates - self._token_nl = Llama.token_nl() - self._token_eos = Llama.token_eos() + self._token_nl = self.token_nl() + self._token_eos = self.token_eos() self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc) # type: ignore self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single) @@ -450,10 +450,14 @@ class Llama: """ assert self.ctx is not None output = b"" + buffer_size = 32 + 
buffer = (ctypes.c_char * buffer_size)() for token in tokens: - output += llama_cpp.llama_token_to_str( - self.ctx, llama_cpp.llama_token(token) + n = llama_cpp.llama_token_to_str( + self.ctx, llama_cpp.llama_token(token), buffer, buffer_size ) + assert n <= buffer_size + output += bytes(buffer[:n]) return output def set_cache(self, cache: Optional[BaseLlamaCache]): @@ -1681,20 +1685,20 @@ class Llama: assert self.ctx is not None return LlamaTokenizer(self) - @staticmethod - def token_eos() -> int: + def token_eos(self) -> int: """Return the end-of-sequence token.""" - return llama_cpp.llama_token_eos() + assert self.ctx is not None + return llama_cpp.llama_token_eos(self.ctx) - @staticmethod - def token_bos() -> int: + def token_bos(self) -> int: """Return the beginning-of-sequence token.""" - return llama_cpp.llama_token_bos() + assert self.ctx is not None + return llama_cpp.llama_token_bos(self.ctx) - @staticmethod - def token_nl() -> int: + def token_nl(self) -> int: """Return the newline token.""" - return llama_cpp.llama_token_nl() + assert self.ctx is not None + return llama_cpp.llama_token_nl(self.ctx) @staticmethod def logits_to_logprobs(logits: List[float]) -> List[float]: diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 0fd3209..c9200c6 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -90,26 +90,17 @@ GGML_USE_CUBLAS = hasattr(_lib, "ggml_init_cublas") GGML_CUDA_MAX_DEVICES = ctypes.c_int(16) LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else ctypes.c_int(1) -# #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt' -LLAMA_FILE_MAGIC_GGJT = ctypes.c_uint(0x67676A74) -# #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' -LLAMA_FILE_MAGIC_GGLA = ctypes.c_uint(0x67676C61) -# #define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf' -LLAMA_FILE_MAGIC_GGMF = ctypes.c_uint(0x67676D66) -# #define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml' -LLAMA_FILE_MAGIC_GGML = ctypes.c_uint(0x67676D6C) -# #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' +# define LLAMA_DEFAULT_SEED 0xFFFFFFFF +LLAMA_DEFAULT_SEED = ctypes.c_int(0xFFFFFFFF) + +# define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' LLAMA_FILE_MAGIC_GGSN = ctypes.c_uint(0x6767736E) -# #define LLAMA_FILE_VERSION 3 -LLAMA_FILE_VERSION = c_int(3) -LLAMA_FILE_MAGIC = LLAMA_FILE_MAGIC_GGJT -LLAMA_FILE_MAGIC_UNVERSIONED = LLAMA_FILE_MAGIC_GGML +# define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN -LLAMA_SESSION_VERSION = c_int(1) +# define LLAMA_SESSION_VERSION 1 +LLAMA_SESSION_VERSION = ctypes.c_int(1) -# #define LLAMA_DEFAULT_SEED 0xFFFFFFFF -LLAMA_DEFAULT_SEED = c_int(0xFFFFFFFF) # struct llama_model; llama_model_p = c_void_p @@ -122,6 +113,82 @@ llama_context_p = c_void_p llama_token = c_int llama_token_p = POINTER(llama_token) +# enum llama_log_level { +# LLAMA_LOG_LEVEL_ERROR = 2, +# LLAMA_LOG_LEVEL_WARN = 3, +# LLAMA_LOG_LEVEL_INFO = 4 +# }; +LLAMA_LOG_LEVEL_ERROR = c_int(2) +LLAMA_LOG_LEVEL_WARN = c_int(3) +LLAMA_LOG_LEVEL_INFO = c_int(4) + +# enum llama_vocab_type { +# LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece +# LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding +# }; +LLAMA_VOCAB_TYPE_SPM = c_int(0) +LLAMA_VOCAB_TYPE_BPE = c_int(1) + + +# enum llama_token_type { +# LLAMA_TOKEN_TYPE_UNDEFINED = 0, +# LLAMA_TOKEN_TYPE_NORMAL = 1, +# LLAMA_TOKEN_TYPE_UNKNOWN = 2, +# LLAMA_TOKEN_TYPE_CONTROL = 3, +# LLAMA_TOKEN_TYPE_USER_DEFINED = 4, +# LLAMA_TOKEN_TYPE_UNUSED = 5, +# LLAMA_TOKEN_TYPE_BYTE = 6, +# }; +LLAMA_TOKEN_TYPE_UNDEFINED = c_int(0) 
+LLAMA_TOKEN_TYPE_NORMAL = c_int(1) +LLAMA_TOKEN_TYPE_UNKNOWN = c_int(2) +LLAMA_TOKEN_TYPE_CONTROL = c_int(3) +LLAMA_TOKEN_TYPE_USER_DEFINED = c_int(4) +LLAMA_TOKEN_TYPE_UNUSED = c_int(5) +LLAMA_TOKEN_TYPE_BYTE = c_int(6) + +# enum llama_ftype { +# LLAMA_FTYPE_ALL_F32 = 0, +# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 +# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed +# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed +# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors +# +# LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file +# }; +LLAMA_FTYPE_ALL_F32 = c_int(0) +LLAMA_FTYPE_MOSTLY_F16 = c_int(1) +LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2) +LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3) +LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(4) +LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7) +LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8) +LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9) +LLAMA_FTYPE_MOSTLY_Q2_K = c_int(10) +LLAMA_FTYPE_MOSTLY_Q3_K_S = c_int(11) +LLAMA_FTYPE_MOSTLY_Q3_K_M = c_int(12) +LLAMA_FTYPE_MOSTLY_Q3_K_L = c_int(13) +LLAMA_FTYPE_MOSTLY_Q4_K_S = c_int(14) +LLAMA_FTYPE_MOSTLY_Q4_K_M = c_int(15) +LLAMA_FTYPE_MOSTLY_Q5_K_S = c_int(16) +LLAMA_FTYPE_MOSTLY_Q5_K_M = c_int(17) +LLAMA_FTYPE_MOSTLY_Q6_K = c_int(18) +LLAMA_FTYPE_GUESSED = c_int(1024) + # typedef struct llama_token_data { # llama_token id; // token id @@ -157,35 +224,13 @@ llama_token_data_array_p = POINTER(llama_token_data_array) # typedef void (*llama_progress_callback)(float progress, void *ctx); llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) - -# enum llama_log_level { -# LLAMA_LOG_LEVEL_ERROR = 2, -# LLAMA_LOG_LEVEL_WARN = 3, -# LLAMA_LOG_LEVEL_INFO = 4 -# }; -LLAMA_LOG_LEVEL_ERROR = c_int(2) -LLAMA_LOG_LEVEL_WARN = c_int(3) -LLAMA_LOG_LEVEL_INFO = c_int(4) - - -# // Signature for logging events -# // Note that text includes the new line character at the end for most events. -# // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it -# // if it exists. -# // It might not exist for progress report where '.' is output repeatedly. 
-# typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); -llama_log_callback = ctypes.CFUNCTYPE(None, c_int, c_char_p, c_void_p) - - # struct llama_context_params { # uint32_t seed; // RNG seed, -1 for random # int32_t n_ctx; // text context # int32_t n_batch; // prompt processing batch size -# int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams) -# float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams) # int32_t n_gpu_layers; // number of layers to store in VRAM # int32_t main_gpu; // the GPU that is used for scratch and small tensors -# + # const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) # // ref: https://github.com/ggerganov/llama.cpp/pull/2054 @@ -213,11 +258,9 @@ class llama_context_params(Structure): ("seed", c_uint32), ("n_ctx", c_int32), ("n_batch", c_int32), - ("n_gqa", c_int32), - ("rms_norm_eps", c_float), ("n_gpu_layers", c_int32), ("main_gpu", c_int32), - ("tensor_split", POINTER(c_float)), + ("tensor_split", c_float_p), ("rope_freq_base", c_float), ("rope_freq_scale", c_float), ("progress_callback", llama_progress_callback), @@ -235,50 +278,20 @@ class llama_context_params(Structure): llama_context_params_p = POINTER(llama_context_params) -# enum llama_ftype { -# LLAMA_FTYPE_ALL_F32 = 0, -# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 -# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed -# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed -# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors -# LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors -# }; -LLAMA_FTYPE_ALL_F32 = c_int(0) -LLAMA_FTYPE_MOSTLY_F16 = c_int(1) -LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2) -LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3) -LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(4) -LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7) -LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8) -LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9) -LLAMA_FTYPE_MOSTLY_Q2_K = c_int(10) -LLAMA_FTYPE_MOSTLY_Q3_K_S = c_int(11) -LLAMA_FTYPE_MOSTLY_Q3_K_M = c_int(12) -LLAMA_FTYPE_MOSTLY_Q3_K_L = c_int(13) -LLAMA_FTYPE_MOSTLY_Q4_K_S = c_int(14) -LLAMA_FTYPE_MOSTLY_Q4_K_M = c_int(15) -LLAMA_FTYPE_MOSTLY_Q5_K_S = c_int(16) -LLAMA_FTYPE_MOSTLY_Q5_K_M = c_int(17) -LLAMA_FTYPE_MOSTLY_Q6_K = c_int(18) + +# // Signature for logging events +# // Note that text includes the new line character at the end for most events. +# // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it +# // if it exists. +# // It might not exist for progress report where '.' is output repeatedly. 
+# typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); +llama_log_callback = ctypes.CFUNCTYPE(None, c_int, c_char_p, c_void_p) # // model quantization parameters # typedef struct llama_model_quantize_params { # int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() -# enum llama_ftype ftype; // quantize to this llama_ftype +# enum llama_ftype ftype; // quantize to this llama_ftype # bool allow_requantize; // allow quantizing non-f32/f16 tensors # bool quantize_output_tensor; // quantize output.weight # } llama_model_quantize_params; @@ -370,29 +383,7 @@ class llama_timings(Structure): ] -# // Set callback for all future logging events. -# // If this is not called, or NULL is supplied, everything is output on stderr. -# LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data); -def llama_log_set( - log_callback: "ctypes._FuncPointer", user_data: c_void_p # type: ignore -): - return _lib.llama_log_set(log_callback, user_data) - - -_lib.llama_log_set.argtypes = [llama_log_callback, c_void_p] -_lib.llama_log_set.restype = None - - -# LLAMA_API int llama_max_devices(); -def llama_max_devices() -> int: - return _lib.llama_max_devices() - - -_lib.llama_max_devices.argtypes = [] -_lib.llama_max_devices.restype = c_int - - -# LLAMA_API struct llama_context_params llama_context_default_params(); +# LLAMA_API struct llama_context_params llama_context_default_params(void); def llama_context_default_params() -> llama_context_params: return _lib.llama_context_default_params() @@ -401,7 +392,7 @@ _lib.llama_context_default_params.argtypes = [] _lib.llama_context_default_params.restype = llama_context_params -# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(); +# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); def llama_model_quantize_default_params() -> llama_model_quantize_params: return _lib.llama_model_quantize_default_params() @@ -410,25 +401,6 @@ _lib.llama_model_quantize_default_params.argtypes = [] _lib.llama_model_quantize_default_params.restype = llama_model_quantize_params -# LLAMA_API bool llama_mmap_supported(); -def llama_mmap_supported() -> bool: - return _lib.llama_mmap_supported() - - -_lib.llama_mmap_supported.argtypes = [] -_lib.llama_mmap_supported.restype = c_bool - - -# LLAMA_API bool llama_mlock_supported(); -def llama_mlock_supported() -> bool: - return _lib.llama_mlock_supported() - - -_lib.llama_mlock_supported.argtypes = [] -_lib.llama_mlock_supported.restype = c_bool - - -# // TODO: not great API - very likely to change # // Initialize the llama + ggml backend # // If numa is true, use NUMA optimizations # // Call once at the start of the program @@ -442,7 +414,7 @@ _lib.llama_backend_init.restype = None # // Call once at the end of the program - currently only used for MPI -# LLAMA_API void llama_backend_free(); +# LLAMA_API void llama_backend_free(void); def llama_backend_free(): return _lib.llama_backend_free() @@ -452,7 +424,7 @@ _lib.llama_backend_free.restype = None # LLAMA_API struct llama_model * llama_load_model_from_file( -# const char * path_model, +# const char * path_model, # struct llama_context_params params); def llama_load_model_from_file( path_model: bytes, params: llama_context_params @@ -474,7 +446,7 @@ _lib.llama_free_model.restype = None # LLAMA_API struct llama_context * llama_new_context_with_model( -# struct llama_model * model, +# struct llama_model * 
model, # struct llama_context_params params); def llama_new_context_with_model( model: llama_model_p, params: llama_context_params @@ -486,7 +458,17 @@ _lib.llama_new_context_with_model.argtypes = [llama_model_p, llama_context_param _lib.llama_new_context_with_model.restype = llama_context_p -# LLAMA_API int64_t llama_time_us(); +# // Frees all allocated memory +# LLAMA_API void llama_free(struct llama_context * ctx); +def llama_free(ctx: llama_context_p): + return _lib.llama_free(ctx) + + +_lib.llama_free.argtypes = [llama_context_p] +_lib.llama_free.restype = None + + +# LLAMA_API int64_t llama_time_us(void); def llama_time_us() -> int: return _lib.llama_time_us() @@ -495,30 +477,95 @@ _lib.llama_time_us.argtypes = [] _lib.llama_time_us.restype = ctypes.c_int64 -# // Various functions for loading a ggml llama model. -# // Allocate (almost) all memory needed for the model. -# // Return NULL on failure -# LLAMA_API struct llama_context * llama_init_from_file( -# const char * path_model, -# struct llama_context_params params); -def llama_init_from_file( - path_model: bytes, params: llama_context_params -) -> llama_context_p: - return _lib.llama_init_from_file(path_model, params) +# LLAMA_API int llama_max_devices (void); +def llama_max_devices() -> int: + return _lib.llama_max_devices() -_lib.llama_init_from_file.argtypes = [c_char_p, llama_context_params] -_lib.llama_init_from_file.restype = llama_context_p +_lib.llama_max_devices.argtypes = [] +_lib.llama_max_devices.restype = c_int -# Frees all allocated memory -# LLAMA_API void llama_free(struct llama_context * ctx); -def llama_free(ctx: llama_context_p): - return _lib.llama_free(ctx) +# LLAMA_API bool llama_mmap_supported (void); +def llama_mmap_supported() -> bool: + return _lib.llama_mmap_supported() -_lib.llama_free.argtypes = [llama_context_p] -_lib.llama_free.restype = None +_lib.llama_mmap_supported.argtypes = [] +_lib.llama_mmap_supported.restype = c_bool + + +# LLAMA_API bool llama_mlock_supported(void); +def llama_mlock_supported() -> bool: + return _lib.llama_mlock_supported() + + +_lib.llama_mlock_supported.argtypes = [] +_lib.llama_mlock_supported.restype = c_bool + + +# LLAMA_API int llama_n_vocab(const struct llama_context * ctx); +def llama_n_vocab(ctx: llama_context_p) -> int: + return _lib.llama_n_vocab(ctx) + + +_lib.llama_n_vocab.argtypes = [llama_context_p] +_lib.llama_n_vocab.restype = c_int + + +# LLAMA_API int llama_n_ctx (const struct llama_context * ctx); +def llama_n_ctx(ctx: llama_context_p) -> int: + return _lib.llama_n_ctx(ctx) + + +_lib.llama_n_ctx.argtypes = [llama_context_p] +_lib.llama_n_ctx.restype = c_int + + +# LLAMA_API int llama_n_embd (const struct llama_context * ctx); +def llama_n_embd(ctx: llama_context_p) -> int: + return _lib.llama_n_embd(ctx) + + +_lib.llama_n_embd.argtypes = [llama_context_p] +_lib.llama_n_embd.restype = c_int + + +# LLAMA_API int llama_model_n_vocab(const struct llama_model * model); +def llama_model_n_vocab(model: llama_model_p) -> int: + return _lib.llama_model_n_vocab(model) + + +_lib.llama_model_n_vocab.argtypes = [llama_model_p] +_lib.llama_model_n_vocab.restype = c_int + + +# LLAMA_API int llama_model_n_ctx (const struct llama_model * model); +def llama_model_n_ctx(model: llama_model_p) -> int: + return _lib.llama_model_n_ctx(model) + + +_lib.llama_model_n_ctx.argtypes = [llama_model_p] +_lib.llama_model_n_ctx.restype = c_int + + +# LLAMA_API int llama_model_n_embd (const struct llama_model * model); +def llama_model_n_embd(model: llama_model_p) -> int: + return 
_lib.llama_model_n_embd(model) + + +_lib.llama_model_n_embd.argtypes = [llama_model_p] +_lib.llama_model_n_embd.restype = c_int + + +# // Get a string describing the model type +# LLAMA_API int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size); +def llama_model_type(model: llama_model_p, buf: bytes, buf_size: c_size_t) -> int: + return _lib.llama_model_type(model, buf, buf_size) + + +_lib.llama_model_type.argtypes = [llama_model_p, c_char_p, c_size_t] +_lib.llama_model_type.restype = c_int # // Returns 0 on success @@ -737,147 +784,17 @@ _lib.llama_eval_embd.argtypes = [llama_context_p, c_float_p, c_int, c_int, c_int _lib.llama_eval_embd.restype = c_int -# Convert the provided text into tokens. -# The tokens pointer must be large enough to hold the resulting tokens. -# Returns the number of tokens on success, no more than n_max_tokens -# Returns a negative number on failure - the number of tokens that would have been returned -# TODO: not sure if correct -# LLAMA_API int llama_tokenize( -# struct llama_context * ctx, -# const char * text, -# llama_token * tokens, -# int n_max_tokens, -# bool add_bos); -def llama_tokenize( - ctx: llama_context_p, - text: bytes, - tokens, # type: Array[llama_token] - n_max_tokens: c_int, - add_bos: c_bool, -) -> int: - return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos) +# // Export a static computation graph for context of 511 and batch size of 1 +# // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these +# // parameters here to keep things simple +# // IMPORTANT: do not use for anything else other than debugging and testing! +# LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname); +def llama_eval_export(ctx: llama_context_p, fname: bytes) -> int: + return _lib.llama_eval_export(ctx, fname) -_lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, c_bool] -_lib.llama_tokenize.restype = c_int - - -# LLAMA_API int llama_tokenize_with_model( -# const struct llama_model * model, -# const char * text, -# llama_token * tokens, -# int n_max_tokens, -# bool add_bos); -def llama_tokenize_with_model( - model: llama_model_p, - text: bytes, - tokens, # type: Array[llama_token] - n_max_tokens: c_int, - add_bos: c_bool, -) -> int: - return _lib.llama_tokenize_with_model(model, text, tokens, n_max_tokens, add_bos) - - -# LLAMA_API int llama_n_vocab(const struct llama_context * ctx); -def llama_n_vocab(ctx: llama_context_p) -> int: - return _lib.llama_n_vocab(ctx) - - -_lib.llama_n_vocab.argtypes = [llama_context_p] -_lib.llama_n_vocab.restype = c_int - - -# LLAMA_API int llama_n_ctx (const struct llama_context * ctx); -def llama_n_ctx(ctx: llama_context_p) -> int: - return _lib.llama_n_ctx(ctx) - - -_lib.llama_n_ctx.argtypes = [llama_context_p] -_lib.llama_n_ctx.restype = c_int - - -# LLAMA_API int llama_n_embd (const struct llama_context * ctx); -def llama_n_embd(ctx: llama_context_p) -> int: - return _lib.llama_n_embd(ctx) - - -_lib.llama_n_embd.argtypes = [llama_context_p] -_lib.llama_n_embd.restype = c_int - - -# LLAMA_API int llama_n_vocab_from_model(const struct llama_model * model); -def llama_n_vocab_from_model(model: llama_model_p) -> int: - return _lib.llama_n_vocab_from_model(model) - - -_lib.llama_n_vocab_from_model.argtypes = [llama_model_p] -_lib.llama_n_vocab_from_model.restype = c_int - - -# LLAMA_API int llama_n_ctx_from_model (const struct llama_model * model); -def llama_n_ctx_from_model(model: llama_model_p) 
-> int: - return _lib.llama_n_ctx_from_model(model) - - -_lib.llama_n_ctx_from_model.argtypes = [llama_model_p] -_lib.llama_n_ctx_from_model.restype = c_int - - -# LLAMA_API int llama_n_embd_from_model (const struct llama_model * model); -def llama_n_embd_from_model(model: llama_model_p) -> int: - return _lib.llama_n_embd_from_model(model) - - -_lib.llama_n_embd_from_model.argtypes = [llama_model_p] -_lib.llama_n_embd_from_model.restype = c_int - - -# // Get the vocabulary as output parameters. -# // Returns number of results. -# LLAMA_API int llama_get_vocab( -# const struct llama_context * ctx, -# const char * * strings, -# float * scores, -# int capacity); -def llama_get_vocab( - ctx: llama_context_p, - strings, # type: Array[c_char_p] # type: ignore - scores, # type: Array[c_float] # type: ignore - capacity: c_int, -) -> int: - return _lib.llama_get_vocab(ctx, strings, scores, capacity) - - -_lib.llama_get_vocab.argtypes = [ - llama_context_p, - POINTER(c_char_p), - POINTER(c_float), - c_int, -] -_lib.llama_get_vocab.restype = c_int - - -# LLAMA_API int llama_get_vocab_from_model( -# const struct llama_model * model, -# const char * * strings, -# float * scores, -# int capacity); -def llama_get_vocab_from_model( - model: llama_model_p, - strings, # type: Array[c_char_p] # type: ignore - scores, # type: Array[c_float] # type: ignore - capacity: c_int, -) -> int: - return _lib.llama_get_vocab_from_model(model, strings, scores, capacity) - - -_lib.llama_get_vocab_from_model.argtypes = [ - llama_model_p, - POINTER(c_char_p), - POINTER(c_float), - c_int, -] -_lib.llama_get_vocab_from_model.restype = c_int +_lib.llama_eval_export.argtypes = [llama_context_p, c_char_p] +_lib.llama_eval_export.restype = c_int # Token logits obtained from the last call to llama_eval() @@ -909,16 +826,186 @@ _lib.llama_get_embeddings.argtypes = [llama_context_p] _lib.llama_get_embeddings.restype = c_float_p +# // +# // Vocab +# // + + +# LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token); +def llama_token_get_text(ctx: llama_context_p, token: llama_token) -> bytes: + return _lib.llama_token_get_text(ctx, token) + + +_lib.llama_token_get_text.argtypes = [llama_context_p, llama_token] +_lib.llama_token_get_text.restype = c_char_p + + +# LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token); +def llama_token_get_score(ctx: llama_context_p, token: llama_token) -> float: + return _lib.llama_token_get_score(ctx, token) + + +_lib.llama_token_get_score.argtypes = [llama_context_p, llama_token] +_lib.llama_token_get_score.restype = c_float + + +# LLAMA_API llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token); +def llama_token_get_type(ctx: llama_context_p, token: llama_token) -> int: + return _lib.llama_token_get_type(ctx, token) + + +_lib.llama_token_get_type.argtypes = [llama_context_p, llama_token] +_lib.llama_token_get_type.restype = ctypes.c_int + + +# // Special tokens + + +# LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence +def llama_token_bos(ctx: llama_context_p) -> llama_token: + return _lib.llama_token_bos(ctx) + + +_lib.llama_token_bos.argtypes = [llama_context_p] +_lib.llama_token_bos.restype = llama_token + + +# LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence +def llama_token_eos(ctx: llama_context_p) -> llama_token: + return _lib.llama_token_eos(ctx) + + +_lib.llama_token_eos.argtypes = 
[llama_context_p] +_lib.llama_token_eos.restype = llama_token + + +# LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line +def llama_token_nl(ctx: llama_context_p) -> llama_token: + return _lib.llama_token_nl(ctx) + + +_lib.llama_token_nl.argtypes = [llama_context_p] +_lib.llama_token_nl.restype = llama_token + + +# // +# // Tokenization +# // + + +# Convert the provided text into tokens. +# The tokens pointer must be large enough to hold the resulting tokens. +# Returns the number of tokens on success, no more than n_max_tokens +# Returns a negative number on failure - the number of tokens that would have been returned +# TODO: not sure if correct +# LLAMA_API int llama_tokenize( +# struct llama_context * ctx, +# const char * text, +# llama_token * tokens, +# int n_max_tokens, +# bool add_bos); +def llama_tokenize( + ctx: llama_context_p, + text: bytes, + tokens, # type: Array[llama_token] + n_max_tokens: c_int, + add_bos: c_bool, +) -> int: + return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos) + + +_lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, c_bool] +_lib.llama_tokenize.restype = c_int + + +# LLAMA_API int llama_tokenize_bpe( +# struct llama_context * ctx, +# const char * text, +# llama_token * tokens, +# int n_max_tokens, +# bool add_bos); +def llama_tokenize_bpe( + ctx: llama_context_p, + text: bytes, + tokens, # type: Array[llama_token] + n_max_tokens: c_int, + add_bos: c_bool, +) -> int: + return _lib.llama_tokenize_bpe(ctx, text, tokens, n_max_tokens, add_bos) + + +_lib.llama_tokenize_bpe.argtypes = [ + llama_context_p, + c_char_p, + llama_token_p, + c_int, + c_bool, +] +_lib.llama_tokenize_bpe.restype = c_int + + +# LLAMA_API int llama_tokenize_with_model( +# const struct llama_model * model, +# const char * text, +# llama_token * tokens, +# int n_max_tokens, +# bool add_bos); +def llama_tokenize_with_model( + model: llama_model_p, + text: bytes, + tokens, # type: Array[llama_token] + n_max_tokens: c_int, + add_bos: c_bool, +) -> int: + return _lib.llama_tokenize_with_model(model, text, tokens, n_max_tokens, add_bos) + + +_lib.llama_tokenize_with_model.argtypes = [ + llama_model_p, + c_char_p, + llama_token_p, + c_int, + c_bool, +] +_lib.llama_tokenize_with_model.restype = c_int + + # // Token Id -> String. 
Uses the vocabulary in the provided context -# LLAMA_API const char * llama_token_to_str( +# // Does not write null terminator to the buffer +# LLAMA_API int llama_token_to_str( # const struct llama_context * ctx, -# llama_token token); -def llama_token_to_str(ctx: llama_context_p, token: llama_token) -> bytes: - return _lib.llama_token_to_str(ctx, token) +# llama_token token, +# char * buf, +# int length); +def llama_token_to_str( + ctx: llama_context_p, token: llama_token, buf: bytes, length: c_int +) -> int: + return _lib.llama_token_to_str(ctx, token, buf, length) -_lib.llama_token_to_str.argtypes = [llama_context_p, llama_token] -_lib.llama_token_to_str.restype = c_char_p +_lib.llama_tokenize_with_model.argtypes = [ + llama_model_p, + c_char_p, + llama_token_p, + c_int, + c_bool, +] +_lib.llama_tokenize_with_model.restype = c_int + + +# LLAMA_API int llama_token_to_str_bpe( +# const struct llama_context * ctx, +# llama_token token, +# char * buf, +# int length); +def llama_token_to_str_bpe( + ctx: llama_context_p, token: llama_token, buf: bytes, length: c_int +) -> int: + return _lib.llama_token_to_str_bpe(ctx, token, buf, length) + + +_lib.llama_token_to_str_bpe.argtypes = [llama_context_p, llama_token, c_char_p, c_int] +_lib.llama_token_to_str_bpe.restype = c_int # LLAMA_API const char * llama_token_to_str_with_model( @@ -931,38 +1018,12 @@ def llama_token_to_str_with_model(model: llama_model_p, token: llama_token) -> b _lib.llama_token_to_str_with_model.argtypes = [llama_model_p, llama_token] _lib.llama_token_to_str_with_model.restype = c_char_p -# Special tokens - - -# LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence -def llama_token_bos() -> int: - return _lib.llama_token_bos() - - -_lib.llama_token_bos.argtypes = [] -_lib.llama_token_bos.restype = llama_token - - -# LLAMA_API llama_token llama_token_eos(); // end-of-sentence -def llama_token_eos() -> int: - return _lib.llama_token_eos() - - -_lib.llama_token_eos.argtypes = [] -_lib.llama_token_eos.restype = llama_token - - -# LLAMA_API llama_token llama_token_nl(); // next-line -def llama_token_nl() -> int: - return _lib.llama_token_nl() - - -_lib.llama_token_nl.argtypes = [] -_lib.llama_token_nl.restype = llama_token - +# // # // Grammar # // + + # LLAMA_API struct llama_grammar * llama_grammar_init( # const llama_grammar_element ** rules, # size_t n_rules, @@ -992,7 +1053,9 @@ _lib.llama_grammar_free.argtypes = [llama_grammar_p] _lib.llama_grammar_free.restype = None -# Sampling functions +# // +# // Sampling functions +# // # @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. @@ -1351,6 +1414,19 @@ def llama_print_system_info() -> bytes: _lib.llama_print_system_info.argtypes = [] _lib.llama_print_system_info.restype = c_char_p + +# // Set callback for all future logging events. +# // If this is not called, or NULL is supplied, everything is output on stderr. 
+# LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data); +def llama_log_set( + log_callback: "ctypes._FuncPointer", user_data: c_void_p # type: ignore +): + return _lib.llama_log_set(log_callback, user_data) + + +_lib.llama_log_set.argtypes = [llama_log_callback, c_void_p] +_lib.llama_log_set.restype = None + ################################################################################################### diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 604b8bd..f5fe98d 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 604b8bdfa6320bbcb018eebcc1252dfede603c6b +Subproject commit f5fe98d11bdf9e7797bcfb05c0c3601ffc4b9d26
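
With this final llama.cpp sync the series is complete. The most visible API change for Python callers is the grammar work from patch 26 (see also the 0.1.78 changelog entry above): a `LlamaGrammar` object is now passed per completion call rather than to the `Llama` constructor. A minimal sketch of that usage follows; the model path and GBNF file below are placeholders, not part of these patches:

    from llama_cpp import Llama
    from llama_cpp.llama_grammar import LlamaGrammar

    llm = Llama(model_path="./models/llama-2-7b.ggmlv3.q4_0.bin")  # hypothetical model file

    # LlamaGrammar.from_file parses a GBNF grammar file; the resulting object is handed
    # to the call itself, so different calls can use different grammars.
    grammar = LlamaGrammar.from_file("json.gbnf")  # hypothetical grammar file

    out = llm(
        "Return a JSON object describing a cat:",
        grammar=grammar,   # sampling is constrained to strings the grammar accepts
        max_tokens=128,
    )
    print(out["choices"][0]["text"])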