Compare commits
646 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b52672c3da | |||
| 9d6148316c | |||
| 7da0822456 | |||
| d35df0db71 | |||
| 93dc5dee6f | |||
| 2d8fad8230 | |||
| ca2958ff98 | |||
| f60ebc7bf2 | |||
| b072737193 | |||
| 3b509da571 | |||
| 5ddb6a191f | |||
| 1b5fb36c9d | |||
| 942f6eac94 | |||
| 2b3c1d81f0 | |||
| 1f21ef7488 | |||
| b799bca7a3 | |||
| b2b4a9ee7d | |||
| ed805f57ff | |||
| fa6f069577 | |||
| cd2280d1a3 | |||
| 5e5ad634a1 | |||
| 55a27a3fb8 | |||
| 8587cddd6c | |||
| 2bd8e5cb23 | |||
| bfe4baa6ed | |||
| 72a6d7dffe | |||
| afe2f0abe1 | |||
| 09fd007c6e | |||
| 24cf2a7954 | |||
| be3eb62047 | |||
| 9c32fed184 | |||
| 6435d69a6d | |||
| a2276177a3 | |||
| ebd0291ef2 | |||
| 0510ee056d | |||
| 44b572a9e0 | |||
| f9c2ad48c2 | |||
| c275aa4732 | |||
| ff071fc74c | |||
| 8d528e0045 | |||
| fd32e3d6e8 | |||
| 34be3f8be6 | |||
| 3037450c77 | |||
| b7091f93b1 | |||
| ab3cbfc99d | |||
| 26030266d2 | |||
| edda0e324b | |||
| 5407d12bc6 | |||
| 2de42ba690 | |||
| f3301a31d5 | |||
| e6a708aa04 | |||
| e80489135b | |||
| a53db44d40 | |||
| 0698ddb496 | |||
| 0962cbb2e5 | |||
| f69c47d9ae | |||
| 027fc1a85a | |||
| f84230527c | |||
| 0e64a48743 | |||
| ffa8b562e9 | |||
| 56b0104154 | |||
| c0c13e4ed4 | |||
| 89befcaf33 | |||
| 0f1c970179 | |||
| 57d3ac0c0b | |||
| a9f9c60efd | |||
| e109a8b502 | |||
| b81926def6 | |||
| 8cb7864110 | |||
| 7cd9f9ed48 | |||
| 2c2334d4db | |||
| 21ffadc2a6 | |||
| 241f966b1a | |||
| 7d0e4510b8 | |||
| 306e67f32d | |||
| 5c8d7d5d6f | |||
| 0b370f2dd9 | |||
| 887e8a8d84 | |||
| 189214a69d | |||
| cd6d24f111 | |||
| c01cfe4f9a | |||
| fbbe9e6030 | |||
| 43bca6d107 | |||
| 669c60a6bb | |||
| dd39003a9b | |||
| 4bded44b6a | |||
| ec22635b47 | |||
| 29d0541ac9 | |||
| a0f411c87d | |||
| 862d5224dd | |||
| e664bc7632 | |||
| f9052d7ecf | |||
| 7dff34ba4e | |||
| dbc25a386e | |||
| 0ea7d0ec80 | |||
| 1d28b4699b | |||
| e0ca46cd73 | |||
| 5454a55269 | |||
| 40c9a13476 | |||
| bd49bce278 | |||
| 52dd479214 | |||
| c57d5cbdde | |||
| 525caadd8c | |||
| f9fa7421cb | |||
| 342096b4bd | |||
| 55510cbad2 | |||
| 3ab50376b0 | |||
| f8fb61d4ad | |||
| 0d68446323 | |||
| 81dbf4309a | |||
| febfe1c268 | |||
| 2a5f86ed6d | |||
| d3659c8ca0 | |||
| f7f75de7c3 | |||
| f58902818d | |||
| 8da410ed95 | |||
| da44c196b6 | |||
| 36079c6646 | |||
| 135448f513 | |||
| 2e143fd15c | |||
| 0b9526b476 | |||
| f304bc63b8 | |||
| decc7851f2 | |||
| 97108db038 | |||
| 1f1fa71d0c | |||
| 2988334fe5 | |||
| 292d12bed4 | |||
| 509cff6e5c | |||
| 29520df44f | |||
| 9be42e49f9 | |||
| 42cef9c282 | |||
| 3a71099dac | |||
| 356122e990 | |||
| aefcdd6f7f | |||
| 3835a8d5df | |||
| e8188a56c7 | |||
| c42a18e9e5 | |||
| b73d221324 | |||
| cc51ffdb57 | |||
| c8971db435 | |||
| c4e787d47b | |||
| fb48b8f0c5 | |||
| 67600d0a0b | |||
| 5a9ab09bc3 | |||
| 2c06ec5f51 | |||
| d70e07fc45 | |||
| fff7203049 | |||
| 5663980015 | |||
| 8304a7716d | |||
| 523d8c38f9 | |||
| e6299960cc | |||
| fb6d41237c | |||
| e183744cb5 | |||
| 07112e4e98 | |||
| bc15f6cca3 | |||
| 3921fb973c | |||
| 6408b4ad53 | |||
| 326b146d68 | |||
| 1830db0476 | |||
| 3ba6043c62 | |||
| f4a74d3ac7 | |||
| e75f58420c | |||
| 28bb0e770f | |||
| 06f4df52f1 | |||
| a03cbcd5f9 | |||
| df67ae730b | |||
| 9305164bf3 | |||
| 453f4c5175 | |||
| 37a9979459 | |||
| 713f2f73da | |||
| 237499d102 | |||
| 3f811f52fd | |||
| 2ea8054304 | |||
| 488a30e879 | |||
| bc3f425212 | |||
| fd1d6c03cb | |||
| 58b52dfb2f | |||
| 651e92fbbf | |||
| 779619f742 | |||
| 96a5e9fc11 | |||
| eb537b5db4 | |||
| 2da79b13df | |||
| 885f88fb60 | |||
| 3585019831 | |||
| 6d7f3dbbb7 | |||
| 71cf7ad11a | |||
| b748fcf836 | |||
| 7289256114 | |||
| 870ebb8850 | |||
| 517b5c17d6 | |||
| d0ac8d9fc7 | |||
| 761a8ad39a | |||
| 52adc8873b | |||
| 173a5c6290 | |||
| f3b2303428 | |||
| 1870069f80 | |||
| d560f2d1f2 | |||
| f7e2ed20fa | |||
| 10d719ac1b | |||
| 45058b4105 | |||
| 2416b2b7af | |||
| 4263350c5b | |||
| 214047dee1 | |||
| ba0b77a803 | |||
| 6e2be3356d | |||
| 8e884fb3f1 | |||
| 59074df021 | |||
| f853e50589 | |||
| ca03358575 | |||
| ab6abc2c13 | |||
| 0ce35a117c | |||
| 900e848522 | |||
| aafe86d81a | |||
| 43b3a0ac66 | |||
| 02f639e561 | |||
| 76bc27199f | |||
| 1aa7027be1 | |||
| f961937097 | |||
| 7a427d7b03 | |||
| 66a1942524 | |||
| 1173adbe86 | |||
| a5beb6d8f0 | |||
| 0e3b7b6a39 | |||
| 5e705bc31b | |||
| 55ce601502 | |||
| 8f6ecd5c64 | |||
| a51a767407 | |||
| 2ea4dd30c6 | |||
| 80e578d3e3 | |||
| c52353cf8a | |||
| d76ebf0ec3 | |||
| 4be5070427 | |||
| e140c02d51 | |||
| 88643a1ba9 | |||
| b7b585656b | |||
| 4494c0b033 | |||
| aa6416399e | |||
| b313751acf | |||
| b1d05dfe8b | |||
| f8899af113 | |||
| cf29cba084 | |||
| ec9b868aea | |||
| 3ec6c71e43 | |||
| 4ad0083118 | |||
| 1055d4356a | |||
| 5822711ae6 | |||
| b19f5133c3 | |||
| 471ea81a7d | |||
| b1832faaae | |||
| 3a9a1bbb84 | |||
| d8081790f3 | |||
| 493bf8db7e | |||
| d9eba2a44f | |||
| fc061c2fee | |||
| aaa96713d4 | |||
| 02954c1a10 | |||
| 4355f30422 | |||
| 2f07df3177 | |||
| 672e9752a0 | |||
| df0f684c34 | |||
| 21afa134f0 | |||
| 6bcec1ac25 | |||
| fe331ed9bd | |||
| 746abf5e28 | |||
| 4d2c93a04f | |||
| 3959e3cadb | |||
| ec5fdb8b92 | |||
| c030ac1d85 | |||
| d223f7388d | |||
| 816d1344ee | |||
| 4c0c7f4c6e | |||
| 04b6ecadc4 | |||
| e84d952dc0 | |||
| 388130a122 | |||
| bb59057d5d | |||
| 36a4481152 | |||
| efa753678c | |||
| 7f3a567259 | |||
| defbe0f9e9 | |||
| 18862145e4 | |||
| 35558dadf4 | |||
| ae8059ca24 | |||
| 116984feb7 | |||
| 219af75704 | |||
| d76fa7fc37 | |||
| 7b6d14e62a | |||
| 67d707e851 | |||
| e648863d52 | |||
| a7cc1cf309 | |||
| f24db23458 | |||
| d132e344d7 | |||
| 22f41daded | |||
| 7c7feaa033 | |||
| 2f80bd9f87 | |||
| 23e5e8dde9 | |||
| e99aca98ab | |||
| 7e30e97a59 | |||
| db4dfea7ec | |||
| 17254a7692 | |||
| adf188c439 | |||
| 21958a55d1 | |||
| 947827bba0 | |||
| e4a3ffa9c1 | |||
| 1fa3737134 | |||
| e7844e9c8d | |||
| 1c761ae042 | |||
| 56ca84f243 | |||
| 04101bc59e | |||
| 0a247a50f2 | |||
| 0e2714acea | |||
| 36921a3e98 | |||
| c1a127c87c | |||
| c1750bb32d | |||
| 4699c226da | |||
| b05f9b6256 | |||
| 0679712d26 | |||
| cb54750e07 | |||
| 21c45ba0ac | |||
| c0c14e60b4 | |||
| 050b43108c | |||
| 00cc0c6a28 | |||
| bee13d9921 | |||
| f814787144 | |||
| c9bb0c587f | |||
| 8422196e89 | |||
| b70dd51cfa | |||
| 190c07975d | |||
| 011ed540dd | |||
| a9c405fac9 | |||
| 9c174e0940 | |||
| 5c4c4b8b7d | |||
| 764825bbff | |||
| ee4cc8ee3b | |||
| 4b53b89f09 | |||
| a2440f72f6 | |||
| 9c0f346258 | |||
| 11f029c311 | |||
| fb923d5efc | |||
| ace2cc6257 | |||
| 24ac577046 | |||
| e86bfd7667 | |||
| e4043633fc | |||
| a8132d1252 | |||
| 927f4d3a37 | |||
| 66f71c1836 | |||
| b1069196a6 | |||
| ba7248c669 | |||
| 6fc4e36625 | |||
| 7d7c2a62dd | |||
| 5b74df2bfc | |||
| 0c392e7a87 | |||
| f656dfcb32 | |||
| 0fab46f65c | |||
| 37dceb043e | |||
| 7ce374d3b9 | |||
| 6e4415e865 | |||
| 45bad9771d | |||
| 8d60db0f6f | |||
| 1bee519a6f | |||
| 72bfa115a0 | |||
| 7f85b2914d | |||
| b8076bb0bd | |||
| d35d923c76 | |||
| a654bc04f7 | |||
| a71e3f4d98 | |||
| 588962d24e | |||
| 2fa33dde81 | |||
| 7ac9088d5c | |||
| dd60bcbfb7 | |||
| b5cf0f0aef | |||
| 9a1e971126 | |||
| 088d65605a | |||
| c881209b92 | |||
| d7a2e3ddae | |||
| d5af593769 | |||
| df74f86955 | |||
| a3de843fdb | |||
| dc15bc508f | |||
| b8eb7c5fed | |||
| 548cedb869 | |||
| 702191049f | |||
| aea39eeafb | |||
| 23a3f01b2b | |||
| af118501b9 | |||
| d1d17f4f0a | |||
| 6832d60bc0 | |||
| ea95462998 | |||
| 847ee20390 | |||
| 867a96c051 | |||
| 0897e4350e | |||
| d2b10545db | |||
| 85993fbb5a | |||
| fb20a9e120 | |||
| 21b823dd3b | |||
| 618ed2c65f | |||
| 9f81c11ba0 | |||
| 5301c01776 | |||
| d81de2f3d8 | |||
| 1314b4b541 | |||
| 695eb04243 | |||
| e5fc916814 | |||
| 0878e5f4a8 | |||
| 72bcec0ce5 | |||
| d604b9622c | |||
| cf0dd777c8 | |||
| ec272ca8be | |||
| 99a44d87dc | |||
| 16f38abd25 | |||
| cac3c4d45f | |||
| 4167e2e294 | |||
| 6ddb9ee3e3 | |||
| 05aefeddc7 | |||
| 9db75fcfc2 | |||
| 1264275cc3 | |||
| cd6dc4ef7e | |||
| 8cd4a96686 | |||
| 344f3771cb | |||
| 8b851e2eeb | |||
| 24282dceb1 | |||
| 1f0bb8742f | |||
| 0de75505f3 | |||
| e5a244ad5d | |||
| 4433b83378 | |||
| 7049dba778 | |||
| 6405d389aa | |||
| b111f2a779 | |||
| b16186a32a | |||
| abdb4660d4 | |||
| ed3bcae8bd | |||
| 75c5136e5a | |||
| 1781c05adb | |||
| f613da4219 | |||
| d87655afff | |||
| a9da944a5d | |||
| efa778a0ef | |||
| 8b411b234d | |||
| ce7418e274 | |||
| 7c9beb5829 | |||
| 56e0c90445 | |||
| 490d37bb80 | |||
| ea238721f0 | |||
| d417ba2a48 | |||
| c713d01e72 | |||
| f95c6a221b | |||
| 718d4b013c | |||
| d9b9987ad3 | |||
| ba728f3e63 | |||
| d83efbb5bc | |||
| 3cb83404e9 | |||
| 1ae1e361b7 | |||
| 016b1e10d7 | |||
| c3ce6108e3 | |||
| cd67f60e01 | |||
| 07549c967a | |||
| 3d38d85287 | |||
| 6fc76ef954 | |||
| d132a3dfbb | |||
| a6dcc231f8 | |||
| c3d626eb07 | |||
| 6d1c5d4491 | |||
| 30c417fe70 | |||
| 6020db0243 | |||
| d9a7b83ae3 | |||
| 1d5a39e002 | |||
| fd61ae13e5 | |||
| ef67037f8e | |||
| 71c6b1ee99 | |||
| a1c81360a5 | |||
| d156942419 | |||
| 7042a748f5 | |||
| d9d937b7f7 | |||
| 65be657a79 | |||
| b197bb01d3 | |||
| a3ac142c83 | |||
| 342a0ad372 | |||
| 35d948b6e1 | |||
| 6c6d12033f | |||
| 556e0f4b43 | |||
| d50e0711c2 | |||
| e2e53d497f | |||
| 693f5786ac | |||
| 9ece1ce2de | |||
| 36a76bf9db | |||
| d0faf77208 | |||
| c8582fc4a2 | |||
| 60b67e2b47 | |||
| 2c7c30be69 | |||
| 6a320e8bfe | |||
| cb0deb5f9d | |||
| 766f4aae2b | |||
| 4e66d22151 | |||
| 8992babaa3 | |||
| 49043b7b7d | |||
| f2414bfd45 | |||
| 68fbcdaa06 | |||
| 7d91b436e4 | |||
| 40e2f8d9f0 | |||
| 4cb6735541 | |||
| 0351e4fa90 | |||
| 1b2d6c424c | |||
| 28c35d045d | |||
| 1f6a1f0028 | |||
| d7029489d6 | |||
| 12afccd9ca | |||
| 81f76111b0 | |||
| 96dac22194 | |||
| 2d36819503 | |||
| 8e20a7e035 | |||
| 4920c5940f | |||
| 3744118311 | |||
| 5ada0b95e9 | |||
| 19eaf5d956 | |||
| 365d175100 | |||
| c3ca68d25b | |||
| eaa9ceeb43 | |||
| 949fac192f | |||
| 4b96d10bc3 | |||
| c16870277c | |||
| 247e3c1470 | |||
| 2af4af6390 | |||
| 749e9977a0 | |||
| 1c61ab6bd9 | |||
| e9f1a8e39b | |||
| b6a51c955e | |||
| 634c1f6752 | |||
| 6ebb816e56 | |||
| 37862f74fa | |||
| 67546746d4 | |||
| d44b6b7f1b | |||
| 3576f44a57 | |||
| 4768ea624d | |||
| e3f9894caf | |||
| 19c8ad3d3d | |||
| bd3b0c712b | |||
| 46176c8029 | |||
| b798062501 | |||
| 63e88326a8 | |||
| 474301adc6 | |||
| 285300528b | |||
| 673f132151 | |||
| 8d0a96a8bf | |||
| cfa87e77a9 | |||
| 60e38e82ec | |||
| ce430fed4c | |||
| 6794e79bb4 | |||
| 181077b785 | |||
| 63635744bf | |||
| 2158c44efd | |||
| e6cf1c94a8 | |||
| d998cac319 | |||
| 6c84e26e70 | |||
| f4d61c168b | |||
| 8feb9e4656 | |||
| 25a1f1867f | |||
| 5e5c92663d | |||
| 942950f5b9 | |||
| d3687d3e81 | |||
| 43b8ecd172 | |||
| 606f57a3ab | |||
| 23b9d88a76 | |||
| c0b88018eb | |||
| fc4080c58a | |||
| 91b9495b04 | |||
| c2769dffe0 | |||
| 71e35311f5 | |||
| 97990e7ad5 | |||
| 73f39a7761 | |||
| 1ecfe68675 | |||
| 447594be28 | |||
| 9d1483c7e6 | |||
| 8e07f9ca56 | |||
| 57be18c026 | |||
| 99369b926c | |||
| 2633272ea9 | |||
| 2ba219fa4b | |||
| 9a423c3487 | |||
| 5479bb0e0c | |||
| c51e7b4af7 | |||
| 7d2c786acc | |||
| b72f522e30 | |||
| 352980311b | |||
| b411b979cb | |||
| ac739e485f | |||
| 8758e2e8d7 | |||
| 17e87478d2 | |||
| a5359e61e7 | |||
| 25b0ae7979 | |||
| dfe72b9d97 | |||
| 780ddd102b | |||
| 8cdbbcaaa2 | |||
| a2f0d14f29 | |||
| 2219695d92 | |||
| d23e9a9bed | |||
| add945e53c | |||
| c1ac32737d | |||
| 14b049d658 | |||
| 002c459981 | |||
| ce660a4413 | |||
| ee579af566 | |||
| caa944e752 | |||
| 00110fb3c3 | |||
| 3543b755af | |||
| 51185354dd | |||
| 9e845a6e53 | |||
| 00a0c56598 | |||
| 30da22e1c1 | |||
| e7d3f1f3ba | |||
| c1da1fdcd5 | |||
| f7c5d8a749 | |||
| 9cf7e2f0af | |||
| dd7921d514 | |||
| eb4f0348e1 | |||
| 38b4fd3737 | |||
| 36dd7a3e8d | |||
| dd698f6d5d | |||
| 06a7d19f98 | |||
| 3801532bd3 | |||
| aaacab7de7 | |||
| 4298c6fd9a | |||
| c30505dddd | |||
| 70e24d77a1 | |||
| fa3db2671a | |||
| 6fd9f2a0c5 | |||
| 1f72ce71b7 | |||
| 102a255575 | |||
| 5beb681c70 | |||
| c9a9db318e | |||
| 01e62c067b | |||
| ceb970c559 | |||
| 6894358fe1 | |||
| 3f0f4a04a9 | |||
| c564e1c3dc | |||
| 210d5ade1e | |||
| 33ebedc76d | |||
| 5b80654198 | |||
| 25e53f3c1a | |||
| 103f7b1ebc | |||
| a56937735e | |||
| 7148534401 | |||
| 4e91b0240b | |||
| 5e92a4ce5a | |||
| 471c663fdf | |||
| 64d333204b | |||
| c44af43840 | |||
| b117bbc125 | |||
| b59da08730 |
@@ -45,14 +45,35 @@ MINIMAX_API_KEY=
|
||||
MINIMAX_CN_API_KEY=
|
||||
# MINIMAX_CN_BASE_URL=https://api.minimaxi.com/v1 # Override default base URL
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (OpenCode Zen)
|
||||
# =============================================================================
|
||||
# OpenCode Zen provides curated, tested models (GPT, Claude, Gemini, MiniMax, GLM, Kimi)
|
||||
# Pay-as-you-go pricing. Get your key at: https://opencode.ai/auth
|
||||
OPENCODE_ZEN_API_KEY=
|
||||
# OPENCODE_ZEN_BASE_URL=https://opencode.ai/zen/v1 # Override default base URL
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (OpenCode Go)
|
||||
# =============================================================================
|
||||
# OpenCode Go provides access to open models (GLM-5, Kimi K2.5, MiniMax M2.5)
|
||||
# $10/month subscription. Get your key at: https://opencode.ai/auth
|
||||
OPENCODE_GO_API_KEY=
|
||||
# OPENCODE_GO_BASE_URL=https://opencode.ai/zen/go/v1 # Override default base URL
|
||||
|
||||
# =============================================================================
|
||||
# TOOL API KEYS
|
||||
# =============================================================================
|
||||
|
||||
# Parallel API Key - AI-native web search and extract
|
||||
# Get at: https://parallel.ai
|
||||
PARALLEL_API_KEY=
|
||||
|
||||
# Firecrawl API Key - Web search, extract, and crawl
|
||||
# Get at: https://firecrawl.dev/
|
||||
FIRECRAWL_API_KEY=
|
||||
|
||||
|
||||
# FAL.ai API Key - Image generation
|
||||
# Get at: https://fal.ai/
|
||||
FAL_KEY=
|
||||
|
||||
@@ -5,7 +5,7 @@ Instructions for AI coding assistants and developers working on the hermes-agent
|
||||
## Development Environment
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate # ALWAYS activate before running Python
|
||||
source venv/bin/activate # ALWAYS activate before running Python
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
@@ -23,6 +23,7 @@ hermes-agent/
|
||||
│ ├── prompt_caching.py # Anthropic prompt caching
|
||||
│ ├── auxiliary_client.py # Auxiliary LLM client (vision, summarization)
|
||||
│ ├── model_metadata.py # Model context lengths, token estimation
|
||||
│ ├── models_dev.py # models.dev registry integration (provider-aware context)
|
||||
│ ├── display.py # KawaiiSpinner, tool preview formatting
|
||||
│ ├── skill_commands.py # Skill slash commands (shared CLI/gateway)
|
||||
│ └── trajectory.py # Trajectory saving helpers
|
||||
@@ -44,7 +45,7 @@ hermes-agent/
|
||||
│ ├── terminal_tool.py # Terminal orchestration
|
||||
│ ├── process_registry.py # Background process management
|
||||
│ ├── file_tools.py # File read/write/search/patch
|
||||
│ ├── web_tools.py # Firecrawl search/extract
|
||||
│ ├── web_tools.py # Web search/extract (Parallel + Firecrawl)
|
||||
│ ├── browser_tool.py # Browserbase browser automation
|
||||
│ ├── code_execution_tool.py # execute_code sandbox
|
||||
│ ├── delegate_tool.py # Subagent delegation
|
||||
@@ -129,14 +130,50 @@ Messages follow OpenAI format: `{"role": "system/user/assistant/tool", ...}`. Re
|
||||
- **KawaiiSpinner** (`agent/display.py`) — animated faces during API calls, `┊` activity feed for tool results
|
||||
- `load_cli_config()` in cli.py merges hardcoded defaults + user config YAML
|
||||
- **Skin engine** (`hermes_cli/skin_engine.py`) — data-driven CLI theming; initialized from `display.skin` config key at startup; skins customize banner colors, spinner faces/verbs/wings, tool prefix, response box, branding text
|
||||
- `process_command()` is a method on `HermesCLI` (not in commands.py)
|
||||
- `process_command()` is a method on `HermesCLI` — dispatches on canonical command name resolved via `resolve_command()` from the central registry
|
||||
- Skill slash commands: `agent/skill_commands.py` scans `~/.hermes/skills/`, injects as **user message** (not system prompt) to preserve prompt caching
|
||||
|
||||
### Adding CLI Commands
|
||||
### Slash Command Registry (`hermes_cli/commands.py`)
|
||||
|
||||
1. Add to `COMMANDS` dict in `hermes_cli/commands.py`
|
||||
2. Add handler in `HermesCLI.process_command()` in `cli.py`
|
||||
3. For persistent settings, use `save_config_value()` in `cli.py`
|
||||
All slash commands are defined in a central `COMMAND_REGISTRY` list of `CommandDef` objects. Every downstream consumer derives from this registry automatically:
|
||||
|
||||
- **CLI** — `process_command()` resolves aliases via `resolve_command()`, dispatches on canonical name
|
||||
- **Gateway** — `GATEWAY_KNOWN_COMMANDS` frozenset for hook emission, `resolve_command()` for dispatch
|
||||
- **Gateway help** — `gateway_help_lines()` generates `/help` output
|
||||
- **Telegram** — `telegram_bot_commands()` generates the BotCommand menu
|
||||
- **Slack** — `slack_subcommand_map()` generates `/hermes` subcommand routing
|
||||
- **Autocomplete** — `COMMANDS` flat dict feeds `SlashCommandCompleter`
|
||||
- **CLI help** — `COMMANDS_BY_CATEGORY` dict feeds `show_help()`
|
||||
|
||||
### Adding a Slash Command
|
||||
|
||||
1. Add a `CommandDef` entry to `COMMAND_REGISTRY` in `hermes_cli/commands.py`:
|
||||
```python
|
||||
CommandDef("mycommand", "Description of what it does", "Session",
|
||||
aliases=("mc",), args_hint="[arg]"),
|
||||
```
|
||||
2. Add handler in `HermesCLI.process_command()` in `cli.py`:
|
||||
```python
|
||||
elif canonical == "mycommand":
|
||||
self._handle_mycommand(cmd_original)
|
||||
```
|
||||
3. If the command is available in the gateway, add a handler in `gateway/run.py`:
|
||||
```python
|
||||
if canonical == "mycommand":
|
||||
return await self._handle_mycommand(event)
|
||||
```
|
||||
4. For persistent settings, use `save_config_value()` in `cli.py`
|
||||
|
||||
**CommandDef fields:**
|
||||
- `name` — canonical name without slash (e.g. `"background"`)
|
||||
- `description` — human-readable description
|
||||
- `category` — one of `"Session"`, `"Configuration"`, `"Tools & Skills"`, `"Info"`, `"Exit"`
|
||||
- `aliases` — tuple of alternative names (e.g. `("bg",)`)
|
||||
- `args_hint` — argument placeholder shown in help (e.g. `"<prompt>"`, `"[name]"`)
|
||||
- `cli_only` — only available in the interactive CLI
|
||||
- `gateway_only` — only available in messaging platforms
|
||||
|
||||
**Adding an alias** requires only adding it to the `aliases` tuple on the existing `CommandDef`. No other file changes needed — dispatch, help text, Telegram menu, Slack mapping, and autocomplete all update automatically.
|
||||
|
||||
---
|
||||
|
||||
@@ -235,6 +272,7 @@ hermes_cli/skin_engine.py # SkinConfig dataclass, built-in skins, YAML loader
|
||||
| Spinner verbs | `spinner.thinking_verbs` | `display.py` |
|
||||
| Spinner wings (optional) | `spinner.wings` | `display.py` |
|
||||
| Tool output prefix | `tool_prefix` | `display.py` |
|
||||
| Per-tool emojis | `tool_emojis` | `display.py` → `get_tool_emoji()` |
|
||||
| Agent name | `branding.agent_name` | `banner.py`, `cli.py` |
|
||||
| Welcome message | `branding.welcome` | `cli.py` |
|
||||
| Response box label | `branding.response_label` | `cli.py` |
|
||||
@@ -327,7 +365,10 @@ Rendering bugs in tmux/iTerm2 — ghosting on scroll. Use `curses` (stdlib) inst
|
||||
Leaks as literal `?[K` text under `prompt_toolkit`'s `patch_stdout`. Use space-padding: `f"\r{line}{' ' * pad}"`.
|
||||
|
||||
### `_last_resolved_tool_names` is a process-global in `model_tools.py`
|
||||
When subagents overwrite this global, `execute_code` calls after delegation may fail with missing tool imports. Known bug.
|
||||
`_run_single_child()` in `delegate_tool.py` saves and restores this global around subagent execution. If you add new code that reads this global, be aware it may be temporarily stale during child agent runs.
|
||||
|
||||
### DO NOT hardcode cross-tool references in schema descriptions
|
||||
Tool schema descriptions must not mention tools from other toolsets by name (e.g., `browser_navigate` saying "prefer web_search"). Those tools may be unavailable (missing API keys, disabled toolset), causing the model to hallucinate calls to non-existent tools. If a cross-reference is needed, add it dynamically in `get_tool_definitions()` in `model_tools.py` — see the `browser_navigate` / `execute_code` post-processing blocks for the pattern.
|
||||
|
||||
### Tests must not write to `~/.hermes/`
|
||||
The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HERMES_HOME` to a temp dir. Never hardcode `~/.hermes/` paths in tests.
|
||||
@@ -337,7 +378,7 @@ The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HER
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
source venv/bin/activate
|
||||
python -m pytest tests/ -q # Full suite (~3000 tests, ~3 min)
|
||||
python -m pytest tests/test_model_tools.py -q # Toolset resolution
|
||||
python -m pytest tests/test_cli_init.py -q # CLI config loading
|
||||
|
||||
+2
-2
@@ -136,7 +136,7 @@ hermes-agent/
|
||||
│ ├── auth.py # Provider resolution, OAuth, Nous Portal
|
||||
│ ├── models.py # OpenRouter model selection lists
|
||||
│ ├── banner.py # Welcome banner, ASCII art
|
||||
│ ├── commands.py # Slash command definitions + autocomplete
|
||||
│ ├── commands.py # Central slash command registry (CommandDef), autocomplete, gateway helpers
|
||||
│ ├── callbacks.py # Interactive callbacks (clarify, sudo, approval)
|
||||
│ ├── doctor.py # Diagnostics
|
||||
│ ├── skills_hub.py # Skills Hub CLI + /skills slash command
|
||||
@@ -147,7 +147,7 @@ hermes-agent/
|
||||
│ ├── approval.py # Dangerous command detection + per-session approval
|
||||
│ ├── terminal_tool.py # Terminal orchestration (sudo, env lifecycle, backends)
|
||||
│ ├── file_operations.py # read_file, write_file, search, patch, etc.
|
||||
│ ├── web_tools.py # web_search, web_extract (Firecrawl + Gemini summarization)
|
||||
│ ├── web_tools.py # web_search, web_extract (Parallel/Firecrawl + Gemini summarization)
|
||||
│ ├── vision_tools.py # Image analysis via multimodal models
|
||||
│ ├── delegate_tool.py # Subagent spawning and parallel task execution
|
||||
│ ├── code_execution_tool.py # Sandboxed Python with RPC tool access
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
<img src="assets/banner.png" alt="Hermes Agent" width="100%">
|
||||
</p>
|
||||
|
||||
# Hermes Agent ⚕
|
||||
# Hermes Agent ☤
|
||||
|
||||
<p align="center">
|
||||
<a href="https://hermes-agent.nousresearch.com/docs/"><img src="https://img.shields.io/badge/Docs-hermes--agent.nousresearch.com-FFD700?style=for-the-badge" alt="Documentation"></a>
|
||||
@@ -62,6 +62,24 @@ hermes doctor # Diagnose any issues
|
||||
|
||||
📖 **[Full documentation →](https://hermes-agent.nousresearch.com/docs/)**
|
||||
|
||||
## CLI vs Messaging Quick Reference
|
||||
|
||||
Hermes has two entry points: start the terminal UI with `hermes`, or run the gateway and talk to it from Telegram, Discord, Slack, WhatsApp, Signal, or Email. Once you're in a conversation, many slash commands are shared across both interfaces.
|
||||
|
||||
| Action | CLI | Messaging platforms |
|
||||
|---------|-----|---------------------|
|
||||
| Start chatting | `hermes` | Run `hermes gateway setup` + `hermes gateway start`, then send the bot a message |
|
||||
| Start fresh conversation | `/new` or `/reset` | `/new` or `/reset` |
|
||||
| Change model | `/model [provider:model]` | `/model [provider:model]` |
|
||||
| Set a personality | `/personality [name]` | `/personality [name]` |
|
||||
| Retry or undo the last turn | `/retry`, `/undo` | `/retry`, `/undo` |
|
||||
| Compress context / check usage | `/compress`, `/usage`, `/insights [--days N]` | `/compress`, `/usage`, `/insights [days]` |
|
||||
| Browse skills | `/skills` or `/<skill-name>` | `/skills` or `/<skill-name>` |
|
||||
| Interrupt current work | `Ctrl+C` or send a new message | `/stop` or send a new message |
|
||||
| Platform-specific status | `/platforms` | `/status`, `/sethome` |
|
||||
|
||||
For the full command lists, see the [CLI guide](https://hermes-agent.nousresearch.com/docs/user-guide/cli) and the [Messaging Gateway guide](https://hermes-agent.nousresearch.com/docs/user-guide/messaging).
|
||||
|
||||
---
|
||||
|
||||
## Documentation
|
||||
@@ -128,8 +146,8 @@ git clone https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
git submodule update --init mini-swe-agent # required terminal backend
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uv venv .venv --python 3.11
|
||||
source .venv/bin/activate
|
||||
uv venv venv --python 3.11
|
||||
source venv/bin/activate
|
||||
uv pip install -e ".[all,dev]"
|
||||
uv pip install -e "./mini-swe-agent"
|
||||
python -m pytest tests/ -q
|
||||
|
||||
@@ -0,0 +1,377 @@
|
||||
# Hermes Agent v0.3.0 (v2026.3.17)
|
||||
|
||||
**Release Date:** March 17, 2026
|
||||
|
||||
> The streaming, plugins, and provider release — unified real-time token delivery, first-class plugin architecture, rebuilt provider system with Vercel AI Gateway, native Anthropic provider, smart approvals, live Chrome CDP browser connect, ACP IDE integration, Honcho memory, voice mode, persistent shell, and 50+ bug fixes across every platform.
|
||||
|
||||
---
|
||||
|
||||
## ✨ Highlights
|
||||
|
||||
- **Unified Streaming Infrastructure** — Real-time token-by-token delivery in CLI and all gateway platforms. Responses stream as they're generated instead of arriving as a block. ([#1538](https://github.com/NousResearch/hermes-agent/pull/1538))
|
||||
|
||||
- **First-Class Plugin Architecture** — Drop Python files into `~/.hermes/plugins/` to extend Hermes with custom tools, commands, and hooks. No forking required. ([#1544](https://github.com/NousResearch/hermes-agent/pull/1544), [#1555](https://github.com/NousResearch/hermes-agent/pull/1555))
|
||||
|
||||
- **Native Anthropic Provider** — Direct Anthropic API calls with Claude Code credential auto-discovery, OAuth PKCE flows, and native prompt caching. No OpenRouter middleman needed. ([#1097](https://github.com/NousResearch/hermes-agent/pull/1097))
|
||||
|
||||
- **Smart Approvals + /stop Command** — Codex-inspired approval system that learns which commands are safe and remembers your preferences. `/stop` kills the current agent run immediately. ([#1543](https://github.com/NousResearch/hermes-agent/pull/1543))
|
||||
|
||||
- **Honcho Memory Integration** — Async memory writes, configurable recall modes, session title integration, and multi-user isolation in gateway mode. By @erosika. ([#736](https://github.com/NousResearch/hermes-agent/pull/736))
|
||||
|
||||
- **Voice Mode** — Push-to-talk in CLI, voice notes in Telegram/Discord, Discord voice channel support, and local Whisper transcription via faster-whisper. ([#1299](https://github.com/NousResearch/hermes-agent/pull/1299), [#1185](https://github.com/NousResearch/hermes-agent/pull/1185), [#1429](https://github.com/NousResearch/hermes-agent/pull/1429))
|
||||
|
||||
- **Concurrent Tool Execution** — Multiple independent tool calls now run in parallel via ThreadPoolExecutor, significantly reducing latency for multi-tool turns. ([#1152](https://github.com/NousResearch/hermes-agent/pull/1152))
|
||||
|
||||
- **PII Redaction** — When `privacy.redact_pii` is enabled, personally identifiable information is automatically scrubbed before sending context to LLM providers. ([#1542](https://github.com/NousResearch/hermes-agent/pull/1542))
|
||||
|
||||
- **`/browser connect` via CDP** — Attach browser tools to a live Chrome instance through Chrome DevTools Protocol. Debug, inspect, and interact with pages you already have open. ([#1549](https://github.com/NousResearch/hermes-agent/pull/1549))
|
||||
|
||||
- **Vercel AI Gateway Provider** — Route Hermes through Vercel's AI Gateway for access to their model catalog and infrastructure. ([#1628](https://github.com/NousResearch/hermes-agent/pull/1628))
|
||||
|
||||
- **Centralized Provider Router** — Rebuilt provider system with `call_llm` API, unified `/model` command, auto-detect provider on model switch, and direct endpoint overrides for auxiliary/delegation clients. ([#1003](https://github.com/NousResearch/hermes-agent/pull/1003), [#1506](https://github.com/NousResearch/hermes-agent/pull/1506), [#1375](https://github.com/NousResearch/hermes-agent/pull/1375))
|
||||
|
||||
- **ACP Server (IDE Integration)** — VS Code, Zed, and JetBrains can now connect to Hermes as an agent backend, with full slash command support. ([#1254](https://github.com/NousResearch/hermes-agent/pull/1254), [#1532](https://github.com/NousResearch/hermes-agent/pull/1532))
|
||||
|
||||
- **Persistent Shell Mode** — Local and SSH terminal backends can maintain shell state across tool calls — cd, env vars, and aliases persist. By @alt-glitch. ([#1067](https://github.com/NousResearch/hermes-agent/pull/1067), [#1483](https://github.com/NousResearch/hermes-agent/pull/1483))
|
||||
|
||||
- **Agentic On-Policy Distillation (OPD)** — New RL training environment for distilling agent policies, expanding the Atropos training ecosystem. ([#1149](https://github.com/NousResearch/hermes-agent/pull/1149))
|
||||
|
||||
---
|
||||
|
||||
## 🏗️ Core Agent & Architecture
|
||||
|
||||
### Provider & Model Support
|
||||
- **Centralized provider router** with `call_llm` API and unified `/model` command — switch models and providers seamlessly ([#1003](https://github.com/NousResearch/hermes-agent/pull/1003))
|
||||
- **Vercel AI Gateway** provider support ([#1628](https://github.com/NousResearch/hermes-agent/pull/1628))
|
||||
- **Auto-detect provider** when switching models via `/model` ([#1506](https://github.com/NousResearch/hermes-agent/pull/1506))
|
||||
- **Direct endpoint overrides** for auxiliary and delegation clients — point vision/subagent calls at specific endpoints ([#1375](https://github.com/NousResearch/hermes-agent/pull/1375))
|
||||
- **Native Anthropic auxiliary vision** — use Claude's native vision API instead of routing through OpenAI-compatible endpoints ([#1377](https://github.com/NousResearch/hermes-agent/pull/1377))
|
||||
- Anthropic OAuth flow improvements — auto-run `claude setup-token`, reauthentication, PKCE state persistence, identity fingerprinting ([#1132](https://github.com/NousResearch/hermes-agent/pull/1132), [#1360](https://github.com/NousResearch/hermes-agent/pull/1360), [#1396](https://github.com/NousResearch/hermes-agent/pull/1396), [#1597](https://github.com/NousResearch/hermes-agent/pull/1597))
|
||||
- Fix adaptive thinking without `budget_tokens` for Claude 4.6 models — by @ASRagab ([#1128](https://github.com/NousResearch/hermes-agent/pull/1128))
|
||||
- Fix Anthropic cache markers through adapter — by @brandtcormorant ([#1216](https://github.com/NousResearch/hermes-agent/pull/1216))
|
||||
- Retry Anthropic 429/529 errors and surface details to users — by @0xbyt4 ([#1585](https://github.com/NousResearch/hermes-agent/pull/1585))
|
||||
- Fix Anthropic adapter max_tokens, fallback crash, proxy base_url — by @0xbyt4 ([#1121](https://github.com/NousResearch/hermes-agent/pull/1121))
|
||||
- Fix DeepSeek V3 parser dropping multiple parallel tool calls — by @mr-emmett-one ([#1365](https://github.com/NousResearch/hermes-agent/pull/1365), [#1300](https://github.com/NousResearch/hermes-agent/pull/1300))
|
||||
- Accept unlisted models with warning instead of rejecting ([#1047](https://github.com/NousResearch/hermes-agent/pull/1047), [#1102](https://github.com/NousResearch/hermes-agent/pull/1102))
|
||||
- Skip reasoning params for unsupported OpenRouter models ([#1485](https://github.com/NousResearch/hermes-agent/pull/1485))
|
||||
- MiniMax Anthropic API compatibility fix ([#1623](https://github.com/NousResearch/hermes-agent/pull/1623))
|
||||
- Custom endpoint `/models` verification and `/v1` base URL suggestion ([#1480](https://github.com/NousResearch/hermes-agent/pull/1480))
|
||||
- Resolve delegation providers from `custom_providers` config ([#1328](https://github.com/NousResearch/hermes-agent/pull/1328))
|
||||
- Kimi model additions and User-Agent fix ([#1039](https://github.com/NousResearch/hermes-agent/pull/1039))
|
||||
- Strip `call_id`/`response_item_id` for Mistral compatibility ([#1058](https://github.com/NousResearch/hermes-agent/pull/1058))
|
||||
|
||||
### Agent Loop & Conversation
|
||||
- **Anthropic Context Editing API** support ([#1147](https://github.com/NousResearch/hermes-agent/pull/1147))
|
||||
- Improved context compaction handoff summaries — compressor now preserves more actionable state ([#1273](https://github.com/NousResearch/hermes-agent/pull/1273))
|
||||
- Sync session_id after mid-run context compression ([#1160](https://github.com/NousResearch/hermes-agent/pull/1160))
|
||||
- Session hygiene threshold tuned to 50% for more proactive compression ([#1096](https://github.com/NousResearch/hermes-agent/pull/1096), [#1161](https://github.com/NousResearch/hermes-agent/pull/1161))
|
||||
- Include session ID in system prompt via `--pass-session-id` flag ([#1040](https://github.com/NousResearch/hermes-agent/pull/1040))
|
||||
- Prevent closed OpenAI client reuse across retries ([#1391](https://github.com/NousResearch/hermes-agent/pull/1391))
|
||||
- Sanitize chat payloads and provider precedence ([#1253](https://github.com/NousResearch/hermes-agent/pull/1253))
|
||||
- Handle dict tool call arguments from Codex and local backends ([#1393](https://github.com/NousResearch/hermes-agent/pull/1393), [#1440](https://github.com/NousResearch/hermes-agent/pull/1440))
|
||||
|
||||
### Memory & Sessions
|
||||
- **Improve memory prioritization** — user preferences and corrections weighted above procedural knowledge ([#1548](https://github.com/NousResearch/hermes-agent/pull/1548))
|
||||
- Tighter memory and session recall guidance in system prompts ([#1329](https://github.com/NousResearch/hermes-agent/pull/1329))
|
||||
- Persist CLI token counts to session DB for `/insights` ([#1498](https://github.com/NousResearch/hermes-agent/pull/1498))
|
||||
- Keep Honcho recall out of the cached system prefix ([#1201](https://github.com/NousResearch/hermes-agent/pull/1201))
|
||||
- Correct `seed_ai_identity` to use `session.add_messages()` ([#1475](https://github.com/NousResearch/hermes-agent/pull/1475))
|
||||
- Isolate Honcho session routing for multi-user gateway ([#1500](https://github.com/NousResearch/hermes-agent/pull/1500))
|
||||
|
||||
---
|
||||
|
||||
## 📱 Messaging Platforms (Gateway)
|
||||
|
||||
### Gateway Core
|
||||
- **System gateway service mode** — run as a system-level systemd service, not just user-level ([#1371](https://github.com/NousResearch/hermes-agent/pull/1371))
|
||||
- **Gateway install scope prompts** — choose user vs system scope during setup ([#1374](https://github.com/NousResearch/hermes-agent/pull/1374))
|
||||
- **Reasoning hot reload** — change reasoning settings without restarting the gateway ([#1275](https://github.com/NousResearch/hermes-agent/pull/1275))
|
||||
- Default group sessions to per-user isolation — no more shared state across users in group chats ([#1495](https://github.com/NousResearch/hermes-agent/pull/1495), [#1417](https://github.com/NousResearch/hermes-agent/pull/1417))
|
||||
- Harden gateway restart recovery ([#1310](https://github.com/NousResearch/hermes-agent/pull/1310))
|
||||
- Cancel active runs during shutdown ([#1427](https://github.com/NousResearch/hermes-agent/pull/1427))
|
||||
- SSL certificate auto-detection for NixOS and non-standard systems ([#1494](https://github.com/NousResearch/hermes-agent/pull/1494))
|
||||
- Auto-detect D-Bus session bus for `systemctl --user` on headless servers ([#1601](https://github.com/NousResearch/hermes-agent/pull/1601))
|
||||
- Auto-enable systemd linger during gateway install on headless servers ([#1334](https://github.com/NousResearch/hermes-agent/pull/1334))
|
||||
- Fall back to module entrypoint when `hermes` is not on PATH ([#1355](https://github.com/NousResearch/hermes-agent/pull/1355))
|
||||
- Fix dual gateways on macOS launchd after `hermes update` ([#1567](https://github.com/NousResearch/hermes-agent/pull/1567))
|
||||
- Remove recursive ExecStop from systemd units ([#1530](https://github.com/NousResearch/hermes-agent/pull/1530))
|
||||
- Prevent logging handler accumulation in gateway mode ([#1251](https://github.com/NousResearch/hermes-agent/pull/1251))
|
||||
- Restart on retryable startup failures — by @jplew ([#1517](https://github.com/NousResearch/hermes-agent/pull/1517))
|
||||
- Backfill model on gateway sessions after agent runs ([#1306](https://github.com/NousResearch/hermes-agent/pull/1306))
|
||||
- PID-based gateway kill and deferred config write ([#1499](https://github.com/NousResearch/hermes-agent/pull/1499))
|
||||
|
||||
### Telegram
|
||||
- Buffer media groups to prevent self-interruption from photo bursts ([#1341](https://github.com/NousResearch/hermes-agent/pull/1341), [#1422](https://github.com/NousResearch/hermes-agent/pull/1422))
|
||||
- Retry on transient TLS failures during connect and send ([#1535](https://github.com/NousResearch/hermes-agent/pull/1535))
|
||||
- Harden polling conflict handling ([#1339](https://github.com/NousResearch/hermes-agent/pull/1339))
|
||||
- Escape chunk indicators and inline code in MarkdownV2 ([#1478](https://github.com/NousResearch/hermes-agent/pull/1478), [#1626](https://github.com/NousResearch/hermes-agent/pull/1626))
|
||||
- Check updater/app state before disconnect ([#1389](https://github.com/NousResearch/hermes-agent/pull/1389))
|
||||
|
||||
### Discord
|
||||
- `/thread` command with `auto_thread` config and media metadata fixes ([#1178](https://github.com/NousResearch/hermes-agent/pull/1178))
|
||||
- Auto-thread on @mention, skip mention text in bot threads ([#1438](https://github.com/NousResearch/hermes-agent/pull/1438))
|
||||
- Retry without reply reference for system messages ([#1385](https://github.com/NousResearch/hermes-agent/pull/1385))
|
||||
- Preserve native document and video attachment support ([#1392](https://github.com/NousResearch/hermes-agent/pull/1392))
|
||||
- Defer discord adapter annotations to avoid optional import crashes ([#1314](https://github.com/NousResearch/hermes-agent/pull/1314))
|
||||
|
||||
### Slack
|
||||
- Thread handling overhaul — progress messages, responses, and session isolation all respect threads ([#1103](https://github.com/NousResearch/hermes-agent/pull/1103))
|
||||
- Formatting, reactions, user resolution, and command improvements ([#1106](https://github.com/NousResearch/hermes-agent/pull/1106))
|
||||
- Fix MAX_MESSAGE_LENGTH 3900 → 39000 ([#1117](https://github.com/NousResearch/hermes-agent/pull/1117))
|
||||
- File upload fallback preserves thread context — by @0xbyt4 ([#1122](https://github.com/NousResearch/hermes-agent/pull/1122))
|
||||
- Improve setup guidance ([#1387](https://github.com/NousResearch/hermes-agent/pull/1387))
|
||||
|
||||
### Email
|
||||
- Fix IMAP UID tracking and SMTP TLS verification ([#1305](https://github.com/NousResearch/hermes-agent/pull/1305))
|
||||
- Add `skip_attachments` option via config.yaml ([#1536](https://github.com/NousResearch/hermes-agent/pull/1536))
|
||||
|
||||
### Home Assistant
|
||||
- Event filtering closed by default ([#1169](https://github.com/NousResearch/hermes-agent/pull/1169))
|
||||
|
||||
---
|
||||
|
||||
## 🖥️ CLI & User Experience
|
||||
|
||||
### Interactive CLI
|
||||
- **Persistent CLI status bar** — always-visible model, provider, and token counts ([#1522](https://github.com/NousResearch/hermes-agent/pull/1522))
|
||||
- **File path autocomplete** in the input prompt ([#1545](https://github.com/NousResearch/hermes-agent/pull/1545))
|
||||
- **`/plan` command** — generate implementation plans from specs ([#1372](https://github.com/NousResearch/hermes-agent/pull/1372), [#1381](https://github.com/NousResearch/hermes-agent/pull/1381))
|
||||
- **Major `/rollback` improvements** — richer checkpoint history, clearer UX ([#1505](https://github.com/NousResearch/hermes-agent/pull/1505))
|
||||
- **Preload CLI skills on launch** — skills are ready before the first prompt ([#1359](https://github.com/NousResearch/hermes-agent/pull/1359))
|
||||
- **Centralized slash command registry** — all commands defined once, consumed everywhere ([#1603](https://github.com/NousResearch/hermes-agent/pull/1603))
|
||||
- `/bg` alias for `/background` ([#1590](https://github.com/NousResearch/hermes-agent/pull/1590))
|
||||
- Prefix matching for slash commands — `/mod` resolves to `/model` ([#1320](https://github.com/NousResearch/hermes-agent/pull/1320))
|
||||
- `/new`, `/reset`, `/clear` now start genuinely fresh sessions ([#1237](https://github.com/NousResearch/hermes-agent/pull/1237))
|
||||
- Accept session ID prefixes for session actions ([#1425](https://github.com/NousResearch/hermes-agent/pull/1425))
|
||||
- TUI prompt and accent output now respect active skin ([#1282](https://github.com/NousResearch/hermes-agent/pull/1282))
|
||||
- Centralize tool emoji metadata in registry + skin integration ([#1484](https://github.com/NousResearch/hermes-agent/pull/1484))
|
||||
- "View full command" option added to dangerous command approval — by @teknium1 based on design by community ([#887](https://github.com/NousResearch/hermes-agent/pull/887))
|
||||
- Non-blocking startup update check and banner deduplication ([#1386](https://github.com/NousResearch/hermes-agent/pull/1386))
|
||||
- `/reasoning` command output ordering and inline think extraction fixes ([#1031](https://github.com/NousResearch/hermes-agent/pull/1031))
|
||||
- Verbose mode shows full untruncated output ([#1472](https://github.com/NousResearch/hermes-agent/pull/1472))
|
||||
- Fix `/status` to report live state and tokens ([#1476](https://github.com/NousResearch/hermes-agent/pull/1476))
|
||||
- Seed a default global SOUL.md ([#1311](https://github.com/NousResearch/hermes-agent/pull/1311))
|
||||
|
||||
### Setup & Configuration
|
||||
- **OpenClaw migration** during first-time setup — by @kshitijk4poor ([#981](https://github.com/NousResearch/hermes-agent/pull/981))
|
||||
- `hermes claw migrate` command + migration docs ([#1059](https://github.com/NousResearch/hermes-agent/pull/1059))
|
||||
- Smart vision setup that respects the user's chosen provider ([#1323](https://github.com/NousResearch/hermes-agent/pull/1323))
|
||||
- Handle headless setup flows end-to-end ([#1274](https://github.com/NousResearch/hermes-agent/pull/1274))
|
||||
- Prefer curses over `simple_term_menu` in setup.py ([#1487](https://github.com/NousResearch/hermes-agent/pull/1487))
|
||||
- Show effective model and provider in `/status` ([#1284](https://github.com/NousResearch/hermes-agent/pull/1284))
|
||||
- Config set examples use placeholder syntax ([#1322](https://github.com/NousResearch/hermes-agent/pull/1322))
|
||||
- Reload .env over stale shell overrides ([#1434](https://github.com/NousResearch/hermes-agent/pull/1434))
|
||||
- Fix is_coding_plan NameError crash — by @0xbyt4 ([#1123](https://github.com/NousResearch/hermes-agent/pull/1123))
|
||||
- Add missing packages to setuptools config — by @alt-glitch ([#912](https://github.com/NousResearch/hermes-agent/pull/912))
|
||||
- Installer: clarify why sudo is needed at every prompt ([#1602](https://github.com/NousResearch/hermes-agent/pull/1602))
|
||||
|
||||
---
|
||||
|
||||
## 🔧 Tool System
|
||||
|
||||
### Terminal & Execution
|
||||
- **Persistent shell mode** for local and SSH backends — maintain shell state across tool calls — by @alt-glitch ([#1067](https://github.com/NousResearch/hermes-agent/pull/1067), [#1483](https://github.com/NousResearch/hermes-agent/pull/1483))
|
||||
- **Tirith pre-exec command scanning** — security layer that analyzes commands before execution ([#1256](https://github.com/NousResearch/hermes-agent/pull/1256))
|
||||
- Strip Hermes provider env vars from all subprocess environments ([#1157](https://github.com/NousResearch/hermes-agent/pull/1157), [#1172](https://github.com/NousResearch/hermes-agent/pull/1172), [#1399](https://github.com/NousResearch/hermes-agent/pull/1399), [#1419](https://github.com/NousResearch/hermes-agent/pull/1419)) — initial fix by @eren-karakus0
|
||||
- SSH preflight check ([#1486](https://github.com/NousResearch/hermes-agent/pull/1486))
|
||||
- Docker backend: make cwd workspace mount explicit opt-in ([#1534](https://github.com/NousResearch/hermes-agent/pull/1534))
|
||||
- Add project root to PYTHONPATH in execute_code sandbox ([#1383](https://github.com/NousResearch/hermes-agent/pull/1383))
|
||||
- Eliminate execute_code progress spam on gateway platforms ([#1098](https://github.com/NousResearch/hermes-agent/pull/1098))
|
||||
- Clearer docker backend preflight errors ([#1276](https://github.com/NousResearch/hermes-agent/pull/1276))
|
||||
|
||||
### Browser
|
||||
- **`/browser connect`** — attach browser tools to a live Chrome instance via CDP ([#1549](https://github.com/NousResearch/hermes-agent/pull/1549))
|
||||
- Improve browser cleanup, local browser PATH setup, and screenshot recovery ([#1333](https://github.com/NousResearch/hermes-agent/pull/1333))
|
||||
|
||||
### MCP
|
||||
- **Selective tool loading** with utility policies — filter which MCP tools are available ([#1302](https://github.com/NousResearch/hermes-agent/pull/1302))
|
||||
- Auto-reload MCP tools when `mcp_servers` config changes without restart ([#1474](https://github.com/NousResearch/hermes-agent/pull/1474))
|
||||
- Resolve npx stdio connection failures ([#1291](https://github.com/NousResearch/hermes-agent/pull/1291))
|
||||
- Preserve MCP toolsets when saving platform tool config ([#1421](https://github.com/NousResearch/hermes-agent/pull/1421))
|
||||
|
||||
### Vision
|
||||
- Unify vision backend gating ([#1367](https://github.com/NousResearch/hermes-agent/pull/1367))
|
||||
- Surface actual error reason instead of generic message ([#1338](https://github.com/NousResearch/hermes-agent/pull/1338))
|
||||
- Make Claude image handling work end-to-end ([#1408](https://github.com/NousResearch/hermes-agent/pull/1408))
|
||||
|
||||
### Cron
|
||||
- **Compress cron management into one tool** — single `cronjob` tool replaces multiple commands ([#1343](https://github.com/NousResearch/hermes-agent/pull/1343))
|
||||
- Suppress duplicate cron sends to auto-delivery targets ([#1357](https://github.com/NousResearch/hermes-agent/pull/1357))
|
||||
- Persist cron sessions to SQLite ([#1255](https://github.com/NousResearch/hermes-agent/pull/1255))
|
||||
- Per-job runtime overrides (provider, model, base_url) ([#1398](https://github.com/NousResearch/hermes-agent/pull/1398))
|
||||
- Atomic write in `save_job_output` to prevent data loss on crash ([#1173](https://github.com/NousResearch/hermes-agent/pull/1173))
|
||||
- Preserve thread context for `deliver=origin` ([#1437](https://github.com/NousResearch/hermes-agent/pull/1437))
|
||||
|
||||
### Patch Tool
|
||||
- Avoid corrupting pipe chars in V4A patch apply ([#1286](https://github.com/NousResearch/hermes-agent/pull/1286))
|
||||
- Permissive `block_anchor` thresholds and unicode normalization ([#1539](https://github.com/NousResearch/hermes-agent/pull/1539))
|
||||
|
||||
### Delegation
|
||||
- Add observability metadata to subagent results (model, tokens, duration, tool trace) ([#1175](https://github.com/NousResearch/hermes-agent/pull/1175))
|
||||
|
||||
---
|
||||
|
||||
## 🧩 Skills Ecosystem
|
||||
|
||||
### Skills System
|
||||
- **Integrate skills.sh** as a hub source alongside ClawHub ([#1303](https://github.com/NousResearch/hermes-agent/pull/1303))
|
||||
- Secure skill env setup on load ([#1153](https://github.com/NousResearch/hermes-agent/pull/1153))
|
||||
- Honor policy table for dangerous verdicts ([#1330](https://github.com/NousResearch/hermes-agent/pull/1330))
|
||||
- Harden ClawHub skill search exact matches ([#1400](https://github.com/NousResearch/hermes-agent/pull/1400))
|
||||
- Fix ClawHub skill install — use `/download` ZIP endpoint ([#1060](https://github.com/NousResearch/hermes-agent/pull/1060))
|
||||
- Avoid mislabeling local skills as builtin — by @arceus77-7 ([#862](https://github.com/NousResearch/hermes-agent/pull/862))
|
||||
|
||||
### New Skills
|
||||
- **Linear** project management ([#1230](https://github.com/NousResearch/hermes-agent/pull/1230))
|
||||
- **X/Twitter** via x-cli ([#1285](https://github.com/NousResearch/hermes-agent/pull/1285))
|
||||
- **Telephony** — Twilio, SMS, and AI calls ([#1289](https://github.com/NousResearch/hermes-agent/pull/1289))
|
||||
- **1Password** — by @arceus77-7 ([#883](https://github.com/NousResearch/hermes-agent/pull/883), [#1179](https://github.com/NousResearch/hermes-agent/pull/1179))
|
||||
- **NeuroSkill BCI** integration ([#1135](https://github.com/NousResearch/hermes-agent/pull/1135))
|
||||
- **Blender MCP** for 3D modeling ([#1531](https://github.com/NousResearch/hermes-agent/pull/1531))
|
||||
- **OSS Security Forensics** ([#1482](https://github.com/NousResearch/hermes-agent/pull/1482))
|
||||
- **Parallel CLI** research skill ([#1301](https://github.com/NousResearch/hermes-agent/pull/1301))
|
||||
- **OpenCode** CLI skill ([#1174](https://github.com/NousResearch/hermes-agent/pull/1174))
|
||||
- **ASCII Video** skill refactored — by @SHL0MS ([#1213](https://github.com/NousResearch/hermes-agent/pull/1213), [#1598](https://github.com/NousResearch/hermes-agent/pull/1598))
|
||||
|
||||
---
|
||||
|
||||
## 🎙️ Voice Mode
|
||||
|
||||
- Voice mode foundation — push-to-talk CLI, Telegram/Discord voice notes ([#1299](https://github.com/NousResearch/hermes-agent/pull/1299))
|
||||
- Free local Whisper transcription via faster-whisper ([#1185](https://github.com/NousResearch/hermes-agent/pull/1185))
|
||||
- Discord voice channel reliability fixes ([#1429](https://github.com/NousResearch/hermes-agent/pull/1429))
|
||||
- Restore local STT fallback for gateway voice notes ([#1490](https://github.com/NousResearch/hermes-agent/pull/1490))
|
||||
- Honor `stt.enabled: false` across gateway transcription ([#1394](https://github.com/NousResearch/hermes-agent/pull/1394))
|
||||
- Fix bogus incapability message on Telegram voice notes (Issue [#1033](https://github.com/NousResearch/hermes-agent/issues/1033))
|
||||
|
||||
---
|
||||
|
||||
## 🔌 ACP (IDE Integration)
|
||||
|
||||
- Restore ACP server implementation ([#1254](https://github.com/NousResearch/hermes-agent/pull/1254))
|
||||
- Support slash commands in ACP adapter ([#1532](https://github.com/NousResearch/hermes-agent/pull/1532))
|
||||
|
||||
---
|
||||
|
||||
## 🧪 RL Training
|
||||
|
||||
- **Agentic On-Policy Distillation (OPD)** environment — new RL training environment for agent policy distillation ([#1149](https://github.com/NousResearch/hermes-agent/pull/1149))
|
||||
- Make tinker-atropos RL training fully optional ([#1062](https://github.com/NousResearch/hermes-agent/pull/1062))
|
||||
|
||||
---
|
||||
|
||||
## 🔒 Security & Reliability
|
||||
|
||||
### Security Hardening
|
||||
- **Tirith pre-exec command scanning** — static analysis of terminal commands before execution ([#1256](https://github.com/NousResearch/hermes-agent/pull/1256))
|
||||
- **PII redaction** when `privacy.redact_pii` is enabled ([#1542](https://github.com/NousResearch/hermes-agent/pull/1542))
|
||||
- Strip Hermes provider/gateway/tool env vars from all subprocess environments ([#1157](https://github.com/NousResearch/hermes-agent/pull/1157), [#1172](https://github.com/NousResearch/hermes-agent/pull/1172), [#1399](https://github.com/NousResearch/hermes-agent/pull/1399), [#1419](https://github.com/NousResearch/hermes-agent/pull/1419))
|
||||
- Docker cwd workspace mount now explicit opt-in — never auto-mount host directories ([#1534](https://github.com/NousResearch/hermes-agent/pull/1534))
|
||||
- Escape parens and braces in fork bomb regex pattern ([#1397](https://github.com/NousResearch/hermes-agent/pull/1397))
|
||||
- Harden `.worktreeinclude` path containment ([#1388](https://github.com/NousResearch/hermes-agent/pull/1388))
|
||||
- Use description as `pattern_key` to prevent approval collisions ([#1395](https://github.com/NousResearch/hermes-agent/pull/1395))
|
||||
|
||||
### Reliability
|
||||
- Guard init-time stdio writes ([#1271](https://github.com/NousResearch/hermes-agent/pull/1271))
|
||||
- Session log writes reuse shared atomic JSON helper ([#1280](https://github.com/NousResearch/hermes-agent/pull/1280))
|
||||
- Atomic temp cleanup protected on interrupts ([#1401](https://github.com/NousResearch/hermes-agent/pull/1401))
|
||||
|
||||
---
|
||||
|
||||
## 🐛 Notable Bug Fixes
|
||||
|
||||
- **`/status` always showing 0 tokens** — now reports live state (Issue [#1465](https://github.com/NousResearch/hermes-agent/issues/1465), [#1476](https://github.com/NousResearch/hermes-agent/pull/1476))
|
||||
- **Custom model endpoints not working** — restored config-saved endpoint resolution (Issue [#1460](https://github.com/NousResearch/hermes-agent/issues/1460), [#1373](https://github.com/NousResearch/hermes-agent/pull/1373))
|
||||
- **MCP tools not visible until restart** — auto-reload on config change (Issue [#1036](https://github.com/NousResearch/hermes-agent/issues/1036), [#1474](https://github.com/NousResearch/hermes-agent/pull/1474))
|
||||
- **`hermes tools` removing MCP tools** — preserve MCP toolsets when saving (Issue [#1247](https://github.com/NousResearch/hermes-agent/issues/1247), [#1421](https://github.com/NousResearch/hermes-agent/pull/1421))
|
||||
- **Terminal subprocesses inheriting `OPENAI_BASE_URL`** breaking external tools (Issue [#1002](https://github.com/NousResearch/hermes-agent/issues/1002), [#1399](https://github.com/NousResearch/hermes-agent/pull/1399))
|
||||
- **Background process lost on gateway restart** — improved recovery (Issue [#1144](https://github.com/NousResearch/hermes-agent/issues/1144))
|
||||
- **Cron jobs not persisting state** — now stored in SQLite (Issue [#1416](https://github.com/NousResearch/hermes-agent/issues/1416), [#1255](https://github.com/NousResearch/hermes-agent/pull/1255))
|
||||
- **Cronjob `deliver: origin` not preserving thread context** (Issue [#1219](https://github.com/NousResearch/hermes-agent/issues/1219), [#1437](https://github.com/NousResearch/hermes-agent/pull/1437))
|
||||
- **Gateway systemd service failing to auto-restart** when browser processes orphaned (Issue [#1617](https://github.com/NousResearch/hermes-agent/issues/1617))
|
||||
- **`/background` completion report cut off in Telegram** (Issue [#1443](https://github.com/NousResearch/hermes-agent/issues/1443))
|
||||
- **Model switching not taking effect** (Issue [#1244](https://github.com/NousResearch/hermes-agent/issues/1244), [#1183](https://github.com/NousResearch/hermes-agent/pull/1183))
|
||||
- **`hermes doctor` reporting cronjob as unavailable** (Issue [#878](https://github.com/NousResearch/hermes-agent/issues/878), [#1180](https://github.com/NousResearch/hermes-agent/pull/1180))
|
||||
- **WhatsApp bridge messages not received** from mobile (Issue [#1142](https://github.com/NousResearch/hermes-agent/issues/1142))
|
||||
- **Setup wizard hanging on headless SSH** (Issue [#905](https://github.com/NousResearch/hermes-agent/issues/905), [#1274](https://github.com/NousResearch/hermes-agent/pull/1274))
|
||||
- **Log handler accumulation** degrading gateway performance (Issue [#990](https://github.com/NousResearch/hermes-agent/issues/990), [#1251](https://github.com/NousResearch/hermes-agent/pull/1251))
|
||||
- **Gateway NULL model in DB** (Issue [#987](https://github.com/NousResearch/hermes-agent/issues/987), [#1306](https://github.com/NousResearch/hermes-agent/pull/1306))
|
||||
- **Strict endpoints rejecting replayed tool_calls** (Issue [#893](https://github.com/NousResearch/hermes-agent/issues/893))
|
||||
- **Remaining hardcoded `~/.hermes` paths** — all now respect `HERMES_HOME` (Issue [#892](https://github.com/NousResearch/hermes-agent/issues/892), [#1233](https://github.com/NousResearch/hermes-agent/pull/1233))
|
||||
- **Delegate tool not working with custom inference providers** (Issue [#1011](https://github.com/NousResearch/hermes-agent/issues/1011), [#1328](https://github.com/NousResearch/hermes-agent/pull/1328))
|
||||
- **Skills Guard blocking official skills** (Issue [#1006](https://github.com/NousResearch/hermes-agent/issues/1006), [#1330](https://github.com/NousResearch/hermes-agent/pull/1330))
|
||||
- **Setup writing provider before model selection** (Issue [#1182](https://github.com/NousResearch/hermes-agent/issues/1182))
|
||||
- **`GatewayConfig.get()` AttributeError** crashing all message handling (Issue [#1158](https://github.com/NousResearch/hermes-agent/issues/1158), [#1287](https://github.com/NousResearch/hermes-agent/pull/1287))
|
||||
- **`/update` hard-failing with "command not found"** (Issue [#1049](https://github.com/NousResearch/hermes-agent/issues/1049))
|
||||
- **Image analysis failing silently** (Issue [#1034](https://github.com/NousResearch/hermes-agent/issues/1034), [#1338](https://github.com/NousResearch/hermes-agent/pull/1338))
|
||||
- **API `BadRequestError` from `'dict'` object has no attribute `'strip'`** (Issue [#1071](https://github.com/NousResearch/hermes-agent/issues/1071))
|
||||
- **Slash commands requiring exact full name** — now uses prefix matching (Issue [#928](https://github.com/NousResearch/hermes-agent/issues/928), [#1320](https://github.com/NousResearch/hermes-agent/pull/1320))
|
||||
- **Gateway stops responding when terminal is closed on headless** (Issue [#1005](https://github.com/NousResearch/hermes-agent/issues/1005))
|
||||
|
||||
---
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
- Cover empty cached Anthropic tool-call turns ([#1222](https://github.com/NousResearch/hermes-agent/pull/1222))
|
||||
- Fix stale CI assumptions in parser and quick-command coverage ([#1236](https://github.com/NousResearch/hermes-agent/pull/1236))
|
||||
- Fix gateway async tests without implicit event loop ([#1278](https://github.com/NousResearch/hermes-agent/pull/1278))
|
||||
- Make gateway async tests xdist-safe ([#1281](https://github.com/NousResearch/hermes-agent/pull/1281))
|
||||
- Cross-timezone naive timestamp regression for cron ([#1319](https://github.com/NousResearch/hermes-agent/pull/1319))
|
||||
- Isolate codex provider tests from local env ([#1335](https://github.com/NousResearch/hermes-agent/pull/1335))
|
||||
- Lock retry replacement semantics ([#1379](https://github.com/NousResearch/hermes-agent/pull/1379))
|
||||
- Improve error logging in session search tool — by @aydnOktay ([#1533](https://github.com/NousResearch/hermes-agent/pull/1533))
|
||||
|
||||
---
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- Comprehensive SOUL.md guide ([#1315](https://github.com/NousResearch/hermes-agent/pull/1315))
|
||||
- Voice mode documentation ([#1316](https://github.com/NousResearch/hermes-agent/pull/1316), [#1362](https://github.com/NousResearch/hermes-agent/pull/1362))
|
||||
- Provider contribution guide ([#1361](https://github.com/NousResearch/hermes-agent/pull/1361))
|
||||
- ACP and internal systems implementation guides ([#1259](https://github.com/NousResearch/hermes-agent/pull/1259))
|
||||
- Expand Docusaurus coverage across CLI, tools, skills, and skins ([#1232](https://github.com/NousResearch/hermes-agent/pull/1232))
|
||||
- Terminal backend and Windows troubleshooting ([#1297](https://github.com/NousResearch/hermes-agent/pull/1297))
|
||||
- Skills hub reference section ([#1317](https://github.com/NousResearch/hermes-agent/pull/1317))
|
||||
- Checkpoint, /rollback, and git worktrees guide ([#1493](https://github.com/NousResearch/hermes-agent/pull/1493), [#1524](https://github.com/NousResearch/hermes-agent/pull/1524))
|
||||
- CLI status bar and /usage reference ([#1523](https://github.com/NousResearch/hermes-agent/pull/1523))
|
||||
- Fallback providers + /background command docs ([#1430](https://github.com/NousResearch/hermes-agent/pull/1430))
|
||||
- Gateway service scopes docs ([#1378](https://github.com/NousResearch/hermes-agent/pull/1378))
|
||||
- Slack thread reply behavior docs ([#1407](https://github.com/NousResearch/hermes-agent/pull/1407))
|
||||
- Redesigned landing page with Nous blue palette — by @austinpickett ([#974](https://github.com/NousResearch/hermes-agent/pull/974))
|
||||
- Fix several documentation typos — by @JackTheGit ([#953](https://github.com/NousResearch/hermes-agent/pull/953))
|
||||
- Stabilize website diagrams ([#1405](https://github.com/NousResearch/hermes-agent/pull/1405))
|
||||
- CLI vs messaging quick reference in README ([#1491](https://github.com/NousResearch/hermes-agent/pull/1491))
|
||||
- Add search to Docusaurus ([#1053](https://github.com/NousResearch/hermes-agent/pull/1053))
|
||||
- Home Assistant integration docs ([#1170](https://github.com/NousResearch/hermes-agent/pull/1170))
|
||||
|
||||
---
|
||||
|
||||
## 👥 Contributors
|
||||
|
||||
### Core
|
||||
- **@teknium1** — 220+ PRs spanning every area of the codebase
|
||||
|
||||
### Top Community Contributors
|
||||
|
||||
- **@0xbyt4** (4 PRs) — Anthropic adapter fixes (max_tokens, fallback crash, 429/529 retry), Slack file upload thread context, setup NameError fix
|
||||
- **@erosika** (1 PR) — Honcho memory integration: async writes, memory modes, session title integration
|
||||
- **@SHL0MS** (2 PRs) — ASCII video skill design patterns and refactoring
|
||||
- **@alt-glitch** (2 PRs) — Persistent shell mode for local/SSH backends, setuptools packaging fix
|
||||
- **@arceus77-7** (2 PRs) — 1Password skill, fix skills list mislabeling
|
||||
- **@kshitijk4poor** (1 PR) — OpenClaw migration during setup wizard
|
||||
- **@ASRagab** (1 PR) — Fix adaptive thinking for Claude 4.6 models
|
||||
- **@eren-karakus0** (1 PR) — Strip Hermes provider env vars from subprocess environment
|
||||
- **@mr-emmett-one** (1 PR) — Fix DeepSeek V3 parser multi-tool call support
|
||||
- **@jplew** (1 PR) — Gateway restart on retryable startup failures
|
||||
- **@brandtcormorant** (1 PR) — Fix Anthropic cache control for empty text blocks
|
||||
- **@aydnOktay** (1 PR) — Improve error logging in session search tool
|
||||
- **@austinpickett** (1 PR) — Landing page redesign with Nous blue palette
|
||||
- **@JackTheGit** (1 PR) — Documentation typo fixes
|
||||
|
||||
### All Contributors
|
||||
|
||||
@0xbyt4, @alt-glitch, @arceus77-7, @ASRagab, @austinpickett, @aydnOktay, @brandtcormorant, @eren-karakus0, @erosika, @JackTheGit, @jplew, @kshitijk4poor, @mr-emmett-one, @SHL0MS, @teknium1
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: [v2026.3.12...v2026.3.17](https://github.com/NousResearch/hermes-agent/compare/v2026.3.12...v2026.3.17)
|
||||
+164
-5
@@ -42,7 +42,7 @@ from acp_adapter.events import (
|
||||
make_tool_progress_cb,
|
||||
)
|
||||
from acp_adapter.permissions import make_approval_callback
|
||||
from acp_adapter.session import SessionManager
|
||||
from acp_adapter.session import SessionManager, SessionState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -226,10 +226,19 @@ class HermesACPAgent(acp.Agent):
|
||||
logger.error("prompt: session %s not found", session_id)
|
||||
return PromptResponse(stop_reason="refusal")
|
||||
|
||||
user_text = _extract_text(prompt)
|
||||
if not user_text.strip():
|
||||
user_text = _extract_text(prompt).strip()
|
||||
if not user_text:
|
||||
return PromptResponse(stop_reason="end_turn")
|
||||
|
||||
# Intercept slash commands — handle locally without calling the LLM
|
||||
if user_text.startswith("/"):
|
||||
response_text = self._handle_slash_command(user_text, state)
|
||||
if response_text is not None:
|
||||
if self._conn:
|
||||
update = acp.update_agent_message_text(response_text)
|
||||
await self._conn.session_update(session_id, update)
|
||||
return PromptResponse(stop_reason="end_turn")
|
||||
|
||||
logger.info("Prompt on session %s: %s", session_id, user_text[:100])
|
||||
|
||||
conn = self._conn
|
||||
@@ -295,6 +304,8 @@ class HermesACPAgent(acp.Agent):
|
||||
|
||||
if result.get("messages"):
|
||||
state.history = result["messages"]
|
||||
# Persist updated history so sessions survive process restarts.
|
||||
self.session_manager.save_session(session_id)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if final_response and conn:
|
||||
@@ -315,19 +326,167 @@ class HermesACPAgent(acp.Agent):
|
||||
stop_reason = "cancelled" if state.cancel_event and state.cancel_event.is_set() else "end_turn"
|
||||
return PromptResponse(stop_reason=stop_reason, usage=usage)
|
||||
|
||||
# ---- Model switching ----------------------------------------------------
|
||||
# ---- Slash commands (headless) -------------------------------------------
|
||||
|
||||
_SLASH_COMMANDS = {
|
||||
"help": "Show available commands",
|
||||
"model": "Show or change current model",
|
||||
"tools": "List available tools",
|
||||
"context": "Show conversation context info",
|
||||
"reset": "Clear conversation history",
|
||||
"compact": "Compress conversation context",
|
||||
"version": "Show Hermes version",
|
||||
}
|
||||
|
||||
def _handle_slash_command(self, text: str, state: SessionState) -> str | None:
|
||||
"""Dispatch a slash command and return the response text.
|
||||
|
||||
Returns ``None`` for unrecognized commands so they fall through
|
||||
to the LLM (the user may have typed ``/something`` as prose).
|
||||
"""
|
||||
parts = text.split(maxsplit=1)
|
||||
cmd = parts[0].lstrip("/").lower()
|
||||
args = parts[1].strip() if len(parts) > 1 else ""
|
||||
|
||||
handler = {
|
||||
"help": self._cmd_help,
|
||||
"model": self._cmd_model,
|
||||
"tools": self._cmd_tools,
|
||||
"context": self._cmd_context,
|
||||
"reset": self._cmd_reset,
|
||||
"compact": self._cmd_compact,
|
||||
"version": self._cmd_version,
|
||||
}.get(cmd)
|
||||
|
||||
if handler is None:
|
||||
return None # not a known command — let the LLM handle it
|
||||
|
||||
try:
|
||||
return handler(args, state)
|
||||
except Exception as e:
|
||||
logger.error("Slash command /%s error: %s", cmd, e, exc_info=True)
|
||||
return f"Error executing /{cmd}: {e}"
|
||||
|
||||
def _cmd_help(self, args: str, state: SessionState) -> str:
|
||||
lines = ["Available commands:", ""]
|
||||
for cmd, desc in self._SLASH_COMMANDS.items():
|
||||
lines.append(f" /{cmd:10s} {desc}")
|
||||
lines.append("")
|
||||
lines.append("Unrecognized /commands are sent to the model as normal messages.")
|
||||
return "\n".join(lines)
|
||||
|
||||
def _cmd_model(self, args: str, state: SessionState) -> str:
|
||||
if not args:
|
||||
model = state.model or getattr(state.agent, "model", "unknown")
|
||||
provider = getattr(state.agent, "provider", None) or "auto"
|
||||
return f"Current model: {model}\nProvider: {provider}"
|
||||
|
||||
new_model = args.strip()
|
||||
target_provider = None
|
||||
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
||||
|
||||
# Auto-detect provider for the requested model
|
||||
try:
|
||||
from hermes_cli.models import parse_model_input, detect_provider_for_model
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
except Exception:
|
||||
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
||||
|
||||
state.model = new_model
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=state.session_id,
|
||||
cwd=state.cwd,
|
||||
model=new_model,
|
||||
requested_provider=target_provider or current_provider,
|
||||
)
|
||||
self.session_manager.save_session(state.session_id)
|
||||
provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider
|
||||
logger.info("Session %s: model switched to %s", state.session_id, new_model)
|
||||
return f"Model switched to: {new_model}\nProvider: {provider_label}"
|
||||
|
||||
def _cmd_tools(self, args: str, state: SessionState) -> str:
|
||||
try:
|
||||
from model_tools import get_tool_definitions
|
||||
toolsets = getattr(state.agent, "enabled_toolsets", None) or ["hermes-acp"]
|
||||
tools = get_tool_definitions(enabled_toolsets=toolsets, quiet_mode=True)
|
||||
if not tools:
|
||||
return "No tools available."
|
||||
lines = [f"Available tools ({len(tools)}):"]
|
||||
for t in tools:
|
||||
name = t.get("function", {}).get("name", "?")
|
||||
desc = t.get("function", {}).get("description", "")
|
||||
# Truncate long descriptions
|
||||
if len(desc) > 80:
|
||||
desc = desc[:77] + "..."
|
||||
lines.append(f" {name}: {desc}")
|
||||
return "\n".join(lines)
|
||||
except Exception as e:
|
||||
return f"Could not list tools: {e}"
|
||||
|
||||
def _cmd_context(self, args: str, state: SessionState) -> str:
|
||||
n_messages = len(state.history)
|
||||
if n_messages == 0:
|
||||
return "Conversation is empty (no messages yet)."
|
||||
# Count by role
|
||||
roles: dict[str, int] = {}
|
||||
for msg in state.history:
|
||||
role = msg.get("role", "unknown")
|
||||
roles[role] = roles.get(role, 0) + 1
|
||||
lines = [
|
||||
f"Conversation: {n_messages} messages",
|
||||
f" user: {roles.get('user', 0)}, assistant: {roles.get('assistant', 0)}, "
|
||||
f"tool: {roles.get('tool', 0)}, system: {roles.get('system', 0)}",
|
||||
]
|
||||
model = state.model or getattr(state.agent, "model", "")
|
||||
if model:
|
||||
lines.append(f"Model: {model}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def _cmd_reset(self, args: str, state: SessionState) -> str:
|
||||
state.history.clear()
|
||||
self.session_manager.save_session(state.session_id)
|
||||
return "Conversation history cleared."
|
||||
|
||||
def _cmd_compact(self, args: str, state: SessionState) -> str:
|
||||
if not state.history:
|
||||
return "Nothing to compress — conversation is empty."
|
||||
try:
|
||||
agent = state.agent
|
||||
if hasattr(agent, "compress_context"):
|
||||
agent.compress_context(state.history)
|
||||
self.session_manager.save_session(state.session_id)
|
||||
return f"Context compressed. Messages: {len(state.history)}"
|
||||
return "Context compression not available for this agent."
|
||||
except Exception as e:
|
||||
return f"Compression failed: {e}"
|
||||
|
||||
def _cmd_version(self, args: str, state: SessionState) -> str:
|
||||
return f"Hermes Agent v{HERMES_VERSION}"
|
||||
|
||||
# ---- Model switching (ACP protocol method) -------------------------------
|
||||
|
||||
async def set_session_model(
|
||||
self, model_id: str, session_id: str, **kwargs: Any
|
||||
):
|
||||
"""Switch the model for a session."""
|
||||
"""Switch the model for a session (called by ACP protocol)."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state:
|
||||
state.model = model_id
|
||||
current_provider = getattr(state.agent, "provider", None)
|
||||
current_base_url = getattr(state.agent, "base_url", None)
|
||||
current_api_mode = getattr(state.agent, "api_mode", None)
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=session_id,
|
||||
cwd=state.cwd,
|
||||
model=model_id,
|
||||
requested_provider=current_provider,
|
||||
base_url=current_base_url,
|
||||
api_mode=current_api_mode,
|
||||
)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
||||
return None
|
||||
|
||||
+295
-39
@@ -1,7 +1,15 @@
|
||||
"""ACP session manager — maps ACP sessions to Hermes AIAgent instances."""
|
||||
"""ACP session manager — maps ACP sessions to Hermes AIAgent instances.
|
||||
|
||||
Sessions are persisted to the shared SessionDB (``~/.hermes/state.db``) so they
|
||||
survive process restarts and appear in ``session_search``. When the editor
|
||||
reconnects after idle/restart, the ``load_session`` / ``resume_session`` calls
|
||||
find the persisted session in the database and restore the full conversation
|
||||
history.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
@@ -46,18 +54,26 @@ class SessionState:
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Thread-safe manager for ACP sessions backed by Hermes AIAgent instances."""
|
||||
"""Thread-safe manager for ACP sessions backed by Hermes AIAgent instances.
|
||||
|
||||
def __init__(self, agent_factory=None):
|
||||
Sessions are held in-memory for fast access **and** persisted to the
|
||||
shared SessionDB so they survive process restarts and are searchable
|
||||
via ``session_search``.
|
||||
"""
|
||||
|
||||
def __init__(self, agent_factory=None, db=None):
|
||||
"""
|
||||
Args:
|
||||
agent_factory: Optional callable that creates an AIAgent-like object.
|
||||
Used by tests. When omitted, a real AIAgent is created
|
||||
using the current Hermes runtime provider configuration.
|
||||
db: Optional SessionDB instance. When omitted, the default
|
||||
SessionDB (``~/.hermes/state.db``) is lazily created.
|
||||
"""
|
||||
self._sessions: Dict[str, SessionState] = {}
|
||||
self._lock = Lock()
|
||||
self._agent_factory = agent_factory
|
||||
self._db_instance = db # None → lazy-init on first use
|
||||
|
||||
# ---- public API ---------------------------------------------------------
|
||||
|
||||
@@ -77,54 +93,67 @@ class SessionManager:
|
||||
with self._lock:
|
||||
self._sessions[session_id] = state
|
||||
_register_task_cwd(session_id, cwd)
|
||||
self._persist(state)
|
||||
logger.info("Created ACP session %s (cwd=%s)", session_id, cwd)
|
||||
return state
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[SessionState]:
|
||||
"""Return the session for *session_id*, or ``None``."""
|
||||
"""Return the session for *session_id*, or ``None``.
|
||||
|
||||
If the session is not in memory but exists in the database (e.g. after
|
||||
a process restart), it is transparently restored.
|
||||
"""
|
||||
with self._lock:
|
||||
return self._sessions.get(session_id)
|
||||
state = self._sessions.get(session_id)
|
||||
if state is not None:
|
||||
return state
|
||||
# Attempt to restore from database.
|
||||
return self._restore(session_id)
|
||||
|
||||
def remove_session(self, session_id: str) -> bool:
|
||||
"""Remove a session. Returns True if it existed."""
|
||||
"""Remove a session from memory and database. Returns True if it existed."""
|
||||
with self._lock:
|
||||
existed = self._sessions.pop(session_id, None) is not None
|
||||
if existed:
|
||||
db_existed = self._delete_persisted(session_id)
|
||||
if existed or db_existed:
|
||||
_clear_task_cwd(session_id)
|
||||
return existed
|
||||
return existed or db_existed
|
||||
|
||||
def fork_session(self, session_id: str, cwd: str = ".") -> Optional[SessionState]:
|
||||
"""Deep-copy a session's history into a new session."""
|
||||
import threading
|
||||
|
||||
with self._lock:
|
||||
original = self._sessions.get(session_id)
|
||||
if original is None:
|
||||
return None
|
||||
original = self.get_session(session_id) # checks DB too
|
||||
if original is None:
|
||||
return None
|
||||
|
||||
new_id = str(uuid.uuid4())
|
||||
agent = self._make_agent(
|
||||
session_id=new_id,
|
||||
cwd=cwd,
|
||||
model=original.model or None,
|
||||
)
|
||||
state = SessionState(
|
||||
session_id=new_id,
|
||||
agent=agent,
|
||||
cwd=cwd,
|
||||
model=getattr(agent, "model", original.model) or original.model,
|
||||
history=copy.deepcopy(original.history),
|
||||
cancel_event=threading.Event(),
|
||||
)
|
||||
new_id = str(uuid.uuid4())
|
||||
agent = self._make_agent(
|
||||
session_id=new_id,
|
||||
cwd=cwd,
|
||||
model=original.model or None,
|
||||
)
|
||||
state = SessionState(
|
||||
session_id=new_id,
|
||||
agent=agent,
|
||||
cwd=cwd,
|
||||
model=getattr(agent, "model", original.model) or original.model,
|
||||
history=copy.deepcopy(original.history),
|
||||
cancel_event=threading.Event(),
|
||||
)
|
||||
with self._lock:
|
||||
self._sessions[new_id] = state
|
||||
_register_task_cwd(new_id, cwd)
|
||||
self._persist(state)
|
||||
logger.info("Forked ACP session %s -> %s", session_id, new_id)
|
||||
return state
|
||||
|
||||
def list_sessions(self) -> List[Dict[str, Any]]:
|
||||
"""Return lightweight info dicts for all sessions."""
|
||||
"""Return lightweight info dicts for all sessions (memory + database)."""
|
||||
# Collect in-memory sessions first.
|
||||
with self._lock:
|
||||
return [
|
||||
seen_ids = set(self._sessions.keys())
|
||||
results = [
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
"cwd": s.cwd,
|
||||
@@ -134,23 +163,245 @@ class SessionManager:
|
||||
for s in self._sessions.values()
|
||||
]
|
||||
|
||||
# Merge any persisted sessions not currently in memory.
|
||||
db = self._get_db()
|
||||
if db is not None:
|
||||
try:
|
||||
rows = db.search_sessions(source="acp", limit=1000)
|
||||
for row in rows:
|
||||
sid = row["id"]
|
||||
if sid in seen_ids:
|
||||
continue
|
||||
# Extract cwd from model_config JSON.
|
||||
cwd = "."
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
cwd = json.loads(mc).get("cwd", ".")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
results.append({
|
||||
"session_id": sid,
|
||||
"cwd": cwd,
|
||||
"model": row.get("model") or "",
|
||||
"history_len": row.get("message_count") or 0,
|
||||
})
|
||||
except Exception:
|
||||
logger.debug("Failed to list ACP sessions from DB", exc_info=True)
|
||||
|
||||
return results
|
||||
|
||||
def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]:
|
||||
"""Update the working directory for a session and its tool overrides."""
|
||||
with self._lock:
|
||||
state = self._sessions.get(session_id)
|
||||
if state is None:
|
||||
return None
|
||||
state.cwd = cwd
|
||||
state = self.get_session(session_id) # checks DB too
|
||||
if state is None:
|
||||
return None
|
||||
state.cwd = cwd
|
||||
_register_task_cwd(session_id, cwd)
|
||||
self._persist(state)
|
||||
return state
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Remove all sessions and clear task-specific cwd overrides."""
|
||||
"""Remove all sessions (memory and database) and clear task-specific cwd overrides."""
|
||||
with self._lock:
|
||||
session_ids = list(self._sessions.keys())
|
||||
self._sessions.clear()
|
||||
for session_id in session_ids:
|
||||
_clear_task_cwd(session_id)
|
||||
self._delete_persisted(session_id)
|
||||
# Also remove any DB-only ACP sessions not currently in memory.
|
||||
db = self._get_db()
|
||||
if db is not None:
|
||||
try:
|
||||
rows = db.search_sessions(source="acp", limit=10000)
|
||||
for row in rows:
|
||||
sid = row["id"]
|
||||
_clear_task_cwd(sid)
|
||||
db.delete_session(sid)
|
||||
except Exception:
|
||||
logger.debug("Failed to cleanup ACP sessions from DB", exc_info=True)
|
||||
|
||||
def save_session(self, session_id: str) -> None:
|
||||
"""Persist the current state of a session to the database.
|
||||
|
||||
Called by the server after prompt completion, slash commands that
|
||||
mutate history, and model switches.
|
||||
"""
|
||||
with self._lock:
|
||||
state = self._sessions.get(session_id)
|
||||
if state is not None:
|
||||
self._persist(state)
|
||||
|
||||
# ---- persistence via SessionDB ------------------------------------------
|
||||
|
||||
def _get_db(self):
|
||||
"""Lazily initialise and return the SessionDB instance.
|
||||
|
||||
Returns ``None`` if the DB is unavailable (e.g. import error in a
|
||||
minimal test environment).
|
||||
|
||||
Note: we resolve ``HERMES_HOME`` dynamically rather than relying on
|
||||
the module-level ``DEFAULT_DB_PATH`` constant, because that constant
|
||||
is evaluated at import time and won't reflect env-var changes made
|
||||
later (e.g. by the test fixture ``_isolate_hermes_home``).
|
||||
"""
|
||||
if self._db_instance is not None:
|
||||
return self._db_instance
|
||||
try:
|
||||
import os
|
||||
from pathlib import Path
|
||||
from hermes_state import SessionDB
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
self._db_instance = SessionDB(db_path=hermes_home / "state.db")
|
||||
return self._db_instance
|
||||
except Exception:
|
||||
logger.debug("SessionDB unavailable for ACP persistence", exc_info=True)
|
||||
return None
|
||||
|
||||
def _persist(self, state: SessionState) -> None:
|
||||
"""Write session state to the database.
|
||||
|
||||
Creates the session record if it doesn't exist, then replaces all
|
||||
stored messages with the current in-memory history.
|
||||
"""
|
||||
db = self._get_db()
|
||||
if db is None:
|
||||
return
|
||||
|
||||
# Ensure model is a plain string (not a MagicMock or other proxy).
|
||||
model_str = str(state.model) if state.model else None
|
||||
session_meta = {"cwd": state.cwd}
|
||||
provider = getattr(state.agent, "provider", None)
|
||||
base_url = getattr(state.agent, "base_url", None)
|
||||
api_mode = getattr(state.agent, "api_mode", None)
|
||||
if isinstance(provider, str) and provider.strip():
|
||||
session_meta["provider"] = provider.strip()
|
||||
if isinstance(base_url, str) and base_url.strip():
|
||||
session_meta["base_url"] = base_url.strip()
|
||||
if isinstance(api_mode, str) and api_mode.strip():
|
||||
session_meta["api_mode"] = api_mode.strip()
|
||||
cwd_json = json.dumps(session_meta)
|
||||
|
||||
try:
|
||||
# Ensure the session record exists.
|
||||
existing = db.get_session(state.session_id)
|
||||
if existing is None:
|
||||
db.create_session(
|
||||
session_id=state.session_id,
|
||||
source="acp",
|
||||
model=model_str,
|
||||
model_config={"cwd": state.cwd},
|
||||
)
|
||||
else:
|
||||
# Update model_config (contains cwd) if changed.
|
||||
try:
|
||||
with db._lock:
|
||||
db._conn.execute(
|
||||
"UPDATE sessions SET model_config = ?, model = COALESCE(?, model) WHERE id = ?",
|
||||
(cwd_json, model_str, state.session_id),
|
||||
)
|
||||
db._conn.commit()
|
||||
except Exception:
|
||||
logger.debug("Failed to update ACP session metadata", exc_info=True)
|
||||
|
||||
# Replace stored messages with current history.
|
||||
db.clear_messages(state.session_id)
|
||||
for msg in state.history:
|
||||
db.append_message(
|
||||
session_id=state.session_id,
|
||||
role=msg.get("role", "user"),
|
||||
content=msg.get("content"),
|
||||
tool_name=msg.get("tool_name") or msg.get("name"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist ACP session %s", state.session_id, exc_info=True)
|
||||
|
||||
def _restore(self, session_id: str) -> Optional[SessionState]:
|
||||
"""Load a session from the database into memory, recreating the AIAgent."""
|
||||
import threading
|
||||
|
||||
db = self._get_db()
|
||||
if db is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
row = db.get_session(session_id)
|
||||
except Exception:
|
||||
logger.debug("Failed to query DB for ACP session %s", session_id, exc_info=True)
|
||||
return None
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
# Only restore ACP sessions.
|
||||
if row.get("source") != "acp":
|
||||
return None
|
||||
|
||||
# Extract cwd from model_config.
|
||||
cwd = "."
|
||||
requested_provider = row.get("billing_provider")
|
||||
restored_base_url = row.get("billing_base_url")
|
||||
restored_api_mode = None
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
meta = json.loads(mc)
|
||||
if isinstance(meta, dict):
|
||||
cwd = meta.get("cwd", ".")
|
||||
requested_provider = meta.get("provider") or requested_provider
|
||||
restored_base_url = meta.get("base_url") or restored_base_url
|
||||
restored_api_mode = meta.get("api_mode") or restored_api_mode
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
model = row.get("model") or None
|
||||
|
||||
# Load conversation history.
|
||||
try:
|
||||
history = db.get_messages_as_conversation(session_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to load messages for ACP session %s", session_id, exc_info=True)
|
||||
history = []
|
||||
|
||||
try:
|
||||
agent = self._make_agent(
|
||||
session_id=session_id,
|
||||
cwd=cwd,
|
||||
model=model,
|
||||
requested_provider=requested_provider,
|
||||
base_url=restored_base_url,
|
||||
api_mode=restored_api_mode,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to recreate agent for ACP session %s", session_id, exc_info=True)
|
||||
return None
|
||||
|
||||
state = SessionState(
|
||||
session_id=session_id,
|
||||
agent=agent,
|
||||
cwd=cwd,
|
||||
model=model or getattr(agent, "model", "") or "",
|
||||
history=history,
|
||||
cancel_event=threading.Event(),
|
||||
)
|
||||
with self._lock:
|
||||
self._sessions[session_id] = state
|
||||
_register_task_cwd(session_id, cwd)
|
||||
logger.info("Restored ACP session %s from DB (%d messages)", session_id, len(history))
|
||||
return state
|
||||
|
||||
def _delete_persisted(self, session_id: str) -> bool:
|
||||
"""Delete a session from the database. Returns True if it existed."""
|
||||
db = self._get_db()
|
||||
if db is None:
|
||||
return False
|
||||
try:
|
||||
return db.delete_session(session_id)
|
||||
except Exception:
|
||||
logger.debug("Failed to delete ACP session %s from DB", session_id, exc_info=True)
|
||||
return False
|
||||
|
||||
# ---- internal -----------------------------------------------------------
|
||||
|
||||
@@ -160,6 +411,9 @@ class SessionManager:
|
||||
session_id: str,
|
||||
cwd: str,
|
||||
model: str | None = None,
|
||||
requested_provider: str | None = None,
|
||||
base_url: str | None = None,
|
||||
api_mode: str | None = None,
|
||||
):
|
||||
if self._agent_factory is not None:
|
||||
return self._agent_factory()
|
||||
@@ -171,10 +425,10 @@ class SessionManager:
|
||||
config = load_config()
|
||||
model_cfg = config.get("model")
|
||||
default_model = "anthropic/claude-opus-4.6"
|
||||
requested_provider = None
|
||||
config_provider = None
|
||||
if isinstance(model_cfg, dict):
|
||||
default_model = str(model_cfg.get("default") or default_model)
|
||||
requested_provider = model_cfg.get("provider")
|
||||
config_provider = model_cfg.get("provider")
|
||||
elif isinstance(model_cfg, str) and model_cfg.strip():
|
||||
default_model = model_cfg.strip()
|
||||
|
||||
@@ -187,13 +441,15 @@ class SessionManager:
|
||||
}
|
||||
|
||||
try:
|
||||
runtime = resolve_runtime_provider(requested=requested_provider)
|
||||
runtime = resolve_runtime_provider(requested=requested_provider or config_provider)
|
||||
kwargs.update(
|
||||
{
|
||||
"provider": runtime.get("provider"),
|
||||
"api_mode": runtime.get("api_mode"),
|
||||
"base_url": runtime.get("base_url"),
|
||||
"api_mode": api_mode or runtime.get("api_mode"),
|
||||
"base_url": base_url or runtime.get("base_url"),
|
||||
"api_key": runtime.get("api_key"),
|
||||
"command": runtime.get("command"),
|
||||
"args": list(runtime.get("args") or []),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
|
||||
+369
-19
@@ -45,14 +45,49 @@ _COMMON_BETAS = [
|
||||
"fine-grained-tool-streaming-2025-05-14",
|
||||
]
|
||||
|
||||
# Additional beta headers required for OAuth/subscription auth
|
||||
# Both clawdbot and OpenCode include claude-code-20250219 alongside oauth-2025-04-20.
|
||||
# Without claude-code-20250219, Anthropic's API rejects OAuth tokens with 401.
|
||||
# Additional beta headers required for OAuth/subscription auth.
|
||||
# Matches what Claude Code (and pi-ai / OpenCode) send.
|
||||
_OAUTH_ONLY_BETAS = [
|
||||
"claude-code-20250219",
|
||||
"oauth-2025-04-20",
|
||||
]
|
||||
|
||||
# Claude Code identity — required for OAuth requests to be routed correctly.
|
||||
# Without these, Anthropic's infrastructure intermittently 500s OAuth traffic.
|
||||
# The version must stay reasonably current — Anthropic rejects OAuth requests
|
||||
# when the spoofed user-agent version is too far behind the actual release.
|
||||
_CLAUDE_CODE_VERSION_FALLBACK = "2.1.74"
|
||||
|
||||
|
||||
def _detect_claude_code_version() -> str:
|
||||
"""Detect the installed Claude Code version, fall back to a static constant.
|
||||
|
||||
Anthropic's OAuth infrastructure validates the user-agent version and may
|
||||
reject requests with a version that's too old. Detecting dynamically means
|
||||
users who keep Claude Code updated never hit stale-version 400s.
|
||||
"""
|
||||
import subprocess as _sp
|
||||
|
||||
for cmd in ("claude", "claude-code"):
|
||||
try:
|
||||
result = _sp.run(
|
||||
[cmd, "--version"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
# Output is like "2.1.74 (Claude Code)" or just "2.1.74"
|
||||
version = result.stdout.strip().split()[0]
|
||||
if version and version[0].isdigit():
|
||||
return version
|
||||
except Exception:
|
||||
pass
|
||||
return _CLAUDE_CODE_VERSION_FALLBACK
|
||||
|
||||
|
||||
_CLAUDE_CODE_VERSION = _detect_claude_code_version()
|
||||
_CLAUDE_CODE_SYSTEM_PREFIX = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
_MCP_TOOL_PREFIX = "mcp_"
|
||||
|
||||
|
||||
def _is_oauth_token(key: str) -> bool:
|
||||
"""Check if the key is an OAuth/setup token (not a regular Console API key).
|
||||
@@ -88,10 +123,16 @@ def build_anthropic_client(api_key: str, base_url: str = None):
|
||||
kwargs["base_url"] = base_url
|
||||
|
||||
if _is_oauth_token(api_key):
|
||||
# OAuth access token / setup-token → Bearer auth + beta headers
|
||||
# OAuth access token / setup-token → Bearer auth + Claude Code identity.
|
||||
# Anthropic routes OAuth requests based on user-agent and headers;
|
||||
# without Claude Code's fingerprint, requests get intermittent 500s.
|
||||
all_betas = _COMMON_BETAS + _OAUTH_ONLY_BETAS
|
||||
kwargs["auth_token"] = api_key
|
||||
kwargs["default_headers"] = {"anthropic-beta": ",".join(all_betas)}
|
||||
kwargs["default_headers"] = {
|
||||
"anthropic-beta": ",".join(all_betas),
|
||||
"user-agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
"x-app": "cli",
|
||||
}
|
||||
else:
|
||||
# Regular API key → x-api-key header + common betas
|
||||
kwargs["api_key"] = api_key
|
||||
@@ -189,7 +230,10 @@ def _refresh_oauth_token(creds: Dict[str, Any]) -> Optional[str]:
|
||||
req = urllib.request.Request(
|
||||
"https://console.anthropic.com/v1/oauth/token",
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
@@ -332,12 +376,24 @@ def resolve_anthropic_token() -> Optional[str]:
|
||||
return preferred
|
||||
return cc_token
|
||||
|
||||
# 3. Claude Code credential file
|
||||
# 3. Hermes-managed OAuth credentials (~/.hermes/.anthropic_oauth.json)
|
||||
hermes_creds = read_hermes_oauth_credentials()
|
||||
if hermes_creds:
|
||||
if is_claude_code_token_valid(hermes_creds):
|
||||
logger.debug("Using Hermes-managed OAuth credentials")
|
||||
return hermes_creds["accessToken"]
|
||||
# Expired — try refresh
|
||||
logger.debug("Hermes OAuth token expired — attempting refresh")
|
||||
refreshed = refresh_hermes_oauth_token()
|
||||
if refreshed:
|
||||
return refreshed
|
||||
|
||||
# 4. Claude Code credential file
|
||||
resolved_claude_token = _resolve_claude_code_token_from_credentials(creds)
|
||||
if resolved_claude_token:
|
||||
return resolved_claude_token
|
||||
|
||||
# 4. Regular API key, or a legacy OAuth token saved in ANTHROPIC_API_KEY.
|
||||
# 5. Regular API key, or a legacy OAuth token saved in ANTHROPIC_API_KEY.
|
||||
# This remains as a compatibility fallback for pre-migration Hermes configs.
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY", "").strip()
|
||||
if api_key:
|
||||
@@ -386,24 +442,235 @@ def run_oauth_setup_token() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
# ── Hermes-native PKCE OAuth flow ────────────────────────────────────────
|
||||
# Mirrors the flow used by Claude Code, pi-ai, and OpenCode.
|
||||
# Stores credentials in ~/.hermes/.anthropic_oauth.json (our own file).
|
||||
|
||||
_OAUTH_CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
_OAUTH_TOKEN_URL = "https://console.anthropic.com/v1/oauth/token"
|
||||
_OAUTH_REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback"
|
||||
_OAUTH_SCOPES = "org:create_api_key user:profile user:inference"
|
||||
_HERMES_OAUTH_FILE = Path(os.getenv("HERMES_HOME", str(Path.home() / ".hermes"))) / ".anthropic_oauth.json"
|
||||
|
||||
|
||||
def _generate_pkce() -> tuple:
|
||||
"""Generate PKCE code_verifier and code_challenge (S256)."""
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
|
||||
challenge = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(verifier.encode()).digest()
|
||||
).rstrip(b"=").decode()
|
||||
return verifier, challenge
|
||||
|
||||
|
||||
def run_hermes_oauth_login() -> Optional[str]:
|
||||
"""Run Hermes-native OAuth PKCE flow for Claude Pro/Max subscription.
|
||||
|
||||
Opens a browser to claude.ai for authorization, prompts for the code,
|
||||
exchanges it for tokens, and stores them in ~/.hermes/.anthropic_oauth.json.
|
||||
|
||||
Returns the access token on success, None on failure.
|
||||
"""
|
||||
import time
|
||||
import webbrowser
|
||||
|
||||
verifier, challenge = _generate_pkce()
|
||||
|
||||
# Build authorization URL
|
||||
params = {
|
||||
"code": "true",
|
||||
"client_id": _OAUTH_CLIENT_ID,
|
||||
"response_type": "code",
|
||||
"redirect_uri": _OAUTH_REDIRECT_URI,
|
||||
"scope": _OAUTH_SCOPES,
|
||||
"code_challenge": challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": verifier,
|
||||
}
|
||||
from urllib.parse import urlencode
|
||||
auth_url = f"https://claude.ai/oauth/authorize?{urlencode(params)}"
|
||||
|
||||
print()
|
||||
print("Authorize Hermes with your Claude Pro/Max subscription.")
|
||||
print()
|
||||
print("╭─ Claude Pro/Max Authorization ────────────────────╮")
|
||||
print("│ │")
|
||||
print("│ Open this link in your browser: │")
|
||||
print("╰───────────────────────────────────────────────────╯")
|
||||
print()
|
||||
print(f" {auth_url}")
|
||||
print()
|
||||
|
||||
# Try to open browser automatically (works on desktop, silently fails on headless/SSH)
|
||||
try:
|
||||
webbrowser.open(auth_url)
|
||||
print(" (Browser opened automatically)")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print()
|
||||
print("After authorizing, you'll see a code. Paste it below.")
|
||||
print()
|
||||
try:
|
||||
auth_code = input("Authorization code: ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return None
|
||||
|
||||
if not auth_code:
|
||||
print("No code entered.")
|
||||
return None
|
||||
|
||||
# Split code#state format
|
||||
splits = auth_code.split("#")
|
||||
code = splits[0]
|
||||
state = splits[1] if len(splits) > 1 else ""
|
||||
|
||||
# Exchange code for tokens
|
||||
try:
|
||||
import urllib.request
|
||||
exchange_data = json.dumps({
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": _OAUTH_CLIENT_ID,
|
||||
"code": code,
|
||||
"state": state,
|
||||
"redirect_uri": _OAUTH_REDIRECT_URI,
|
||||
"code_verifier": verifier,
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(
|
||||
_OAUTH_TOKEN_URL,
|
||||
data=exchange_data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
||||
result = json.loads(resp.read().decode())
|
||||
except Exception as e:
|
||||
print(f"Token exchange failed: {e}")
|
||||
return None
|
||||
|
||||
access_token = result.get("access_token", "")
|
||||
refresh_token = result.get("refresh_token", "")
|
||||
expires_in = result.get("expires_in", 3600)
|
||||
|
||||
if not access_token:
|
||||
print("No access token in response.")
|
||||
return None
|
||||
|
||||
# Store credentials
|
||||
expires_at_ms = int(time.time() * 1000) + (expires_in * 1000)
|
||||
_save_hermes_oauth_credentials(access_token, refresh_token, expires_at_ms)
|
||||
|
||||
# Also write to Claude Code's credential file for backward compat
|
||||
_write_claude_code_credentials(access_token, refresh_token, expires_at_ms)
|
||||
|
||||
print("Authentication successful!")
|
||||
return access_token
|
||||
|
||||
|
||||
def _save_hermes_oauth_credentials(access_token: str, refresh_token: str, expires_at_ms: int) -> None:
|
||||
"""Save OAuth credentials to ~/.hermes/.anthropic_oauth.json."""
|
||||
data = {
|
||||
"accessToken": access_token,
|
||||
"refreshToken": refresh_token,
|
||||
"expiresAt": expires_at_ms,
|
||||
}
|
||||
try:
|
||||
_HERMES_OAUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_HERMES_OAUTH_FILE.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
_HERMES_OAUTH_FILE.chmod(0o600)
|
||||
except (OSError, IOError) as e:
|
||||
logger.debug("Failed to save Hermes OAuth credentials: %s", e)
|
||||
|
||||
|
||||
def read_hermes_oauth_credentials() -> Optional[Dict[str, Any]]:
|
||||
"""Read Hermes-managed OAuth credentials from ~/.hermes/.anthropic_oauth.json."""
|
||||
if _HERMES_OAUTH_FILE.exists():
|
||||
try:
|
||||
data = json.loads(_HERMES_OAUTH_FILE.read_text(encoding="utf-8"))
|
||||
if data.get("accessToken"):
|
||||
return data
|
||||
except (json.JSONDecodeError, OSError, IOError) as e:
|
||||
logger.debug("Failed to read Hermes OAuth credentials: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def refresh_hermes_oauth_token() -> Optional[str]:
|
||||
"""Refresh the Hermes-managed OAuth token using the stored refresh token.
|
||||
|
||||
Returns the new access token, or None if refresh fails.
|
||||
"""
|
||||
import time
|
||||
import urllib.request
|
||||
|
||||
creds = read_hermes_oauth_credentials()
|
||||
if not creds or not creds.get("refreshToken"):
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.dumps({
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": creds["refreshToken"],
|
||||
"client_id": _OAUTH_CLIENT_ID,
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(
|
||||
_OAUTH_TOKEN_URL,
|
||||
data=data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
result = json.loads(resp.read().decode())
|
||||
|
||||
new_access = result.get("access_token", "")
|
||||
new_refresh = result.get("refresh_token", creds["refreshToken"])
|
||||
expires_in = result.get("expires_in", 3600)
|
||||
|
||||
if new_access:
|
||||
new_expires_ms = int(time.time() * 1000) + (expires_in * 1000)
|
||||
_save_hermes_oauth_credentials(new_access, new_refresh, new_expires_ms)
|
||||
# Also update Claude Code's credential file
|
||||
_write_claude_code_credentials(new_access, new_refresh, new_expires_ms)
|
||||
logger.debug("Successfully refreshed Hermes OAuth token")
|
||||
return new_access
|
||||
except Exception as e:
|
||||
logger.debug("Failed to refresh Hermes OAuth token: %s", e)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message / tool / response format conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def normalize_model_name(model: str) -> str:
|
||||
def normalize_model_name(model: str, preserve_dots: bool = False) -> str:
|
||||
"""Normalize a model name for the Anthropic API.
|
||||
|
||||
- Strips 'anthropic/' prefix (OpenRouter format, case-insensitive)
|
||||
- Converts dots to hyphens in version numbers (OpenRouter uses dots,
|
||||
Anthropic uses hyphens: claude-opus-4.6 → claude-opus-4-6)
|
||||
Anthropic uses hyphens: claude-opus-4.6 → claude-opus-4-6), unless
|
||||
preserve_dots is True (e.g. for Alibaba/DashScope: qwen3.5-plus).
|
||||
"""
|
||||
lower = model.lower()
|
||||
if lower.startswith("anthropic/"):
|
||||
model = model[len("anthropic/"):]
|
||||
# OpenRouter uses dots for version separators (claude-opus-4.6),
|
||||
# Anthropic uses hyphens (claude-opus-4-6). Convert dots to hyphens.
|
||||
model = model.replace(".", "-")
|
||||
if not preserve_dots:
|
||||
# OpenRouter uses dots for version separators (claude-opus-4.6),
|
||||
# Anthropic uses hyphens (claude-opus-4-6). Convert dots to hyphens.
|
||||
model = model.replace(".", "-")
|
||||
return model
|
||||
|
||||
|
||||
@@ -599,6 +866,8 @@ def convert_messages_to_anthropic(
|
||||
else:
|
||||
blocks.append({"type": "text", "text": str(content)})
|
||||
for tc in m.get("tool_calls", []):
|
||||
if not tc or not isinstance(tc, dict):
|
||||
continue
|
||||
fn = tc.get("function", {})
|
||||
args = fn.get("arguments", "{}")
|
||||
try:
|
||||
@@ -670,6 +939,26 @@ def convert_messages_to_anthropic(
|
||||
if not m["content"]:
|
||||
m["content"] = [{"type": "text", "text": "(tool call removed)"}]
|
||||
|
||||
# Strip orphaned tool_result blocks (no matching tool_use precedes them).
|
||||
# This is the mirror of the above: context compression or session truncation
|
||||
# can remove an assistant message containing a tool_use while leaving the
|
||||
# subsequent tool_result intact. Anthropic rejects these with a 400.
|
||||
tool_use_ids = set()
|
||||
for m in result:
|
||||
if m["role"] == "assistant" and isinstance(m["content"], list):
|
||||
for block in m["content"]:
|
||||
if block.get("type") == "tool_use":
|
||||
tool_use_ids.add(block.get("id"))
|
||||
for m in result:
|
||||
if m["role"] == "user" and isinstance(m["content"], list):
|
||||
m["content"] = [
|
||||
b
|
||||
for b in m["content"]
|
||||
if b.get("type") != "tool_result" or b.get("tool_use_id") in tool_use_ids
|
||||
]
|
||||
if not m["content"]:
|
||||
m["content"] = [{"type": "text", "text": "(tool result removed)"}]
|
||||
|
||||
# Enforce strict role alternation (Anthropic rejects consecutive same-role messages)
|
||||
fixed = []
|
||||
for m in result:
|
||||
@@ -698,8 +987,12 @@ def convert_messages_to_anthropic(
|
||||
elif isinstance(prev_blocks, str) and isinstance(curr_blocks, str):
|
||||
fixed[-1]["content"] = prev_blocks + "\n" + curr_blocks
|
||||
else:
|
||||
# Keep the later message
|
||||
fixed[-1] = m
|
||||
# Mixed types — normalize both to list and merge
|
||||
if isinstance(prev_blocks, str):
|
||||
prev_blocks = [{"type": "text", "text": prev_blocks}]
|
||||
if isinstance(curr_blocks, str):
|
||||
curr_blocks = [{"type": "text", "text": curr_blocks}]
|
||||
fixed[-1]["content"] = prev_blocks + curr_blocks
|
||||
else:
|
||||
fixed.append(m)
|
||||
result = fixed
|
||||
@@ -714,14 +1007,63 @@ def build_anthropic_kwargs(
|
||||
max_tokens: Optional[int],
|
||||
reasoning_config: Optional[Dict[str, Any]],
|
||||
tool_choice: Optional[str] = None,
|
||||
is_oauth: bool = False,
|
||||
preserve_dots: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build kwargs for anthropic.messages.create()."""
|
||||
"""Build kwargs for anthropic.messages.create().
|
||||
|
||||
When *is_oauth* is True, applies Claude Code compatibility transforms:
|
||||
system prompt prefix, tool name prefixing, and prompt sanitization.
|
||||
|
||||
When *preserve_dots* is True, model name dots are not converted to hyphens
|
||||
(for Alibaba/DashScope anthropic-compatible endpoints: qwen3.5-plus).
|
||||
"""
|
||||
system, anthropic_messages = convert_messages_to_anthropic(messages)
|
||||
anthropic_tools = convert_tools_to_anthropic(tools) if tools else []
|
||||
|
||||
model = normalize_model_name(model)
|
||||
model = normalize_model_name(model, preserve_dots=preserve_dots)
|
||||
effective_max_tokens = max_tokens or 16384
|
||||
|
||||
# ── OAuth: Claude Code identity ──────────────────────────────────
|
||||
if is_oauth:
|
||||
# 1. Prepend Claude Code system prompt identity
|
||||
cc_block = {"type": "text", "text": _CLAUDE_CODE_SYSTEM_PREFIX}
|
||||
if isinstance(system, list):
|
||||
system = [cc_block] + system
|
||||
elif isinstance(system, str) and system:
|
||||
system = [cc_block, {"type": "text", "text": system}]
|
||||
else:
|
||||
system = [cc_block]
|
||||
|
||||
# 2. Sanitize system prompt — replace product name references
|
||||
# to avoid Anthropic's server-side content filters.
|
||||
for block in system:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
text = text.replace("Hermes Agent", "Claude Code")
|
||||
text = text.replace("Hermes agent", "Claude Code")
|
||||
text = text.replace("hermes-agent", "claude-code")
|
||||
text = text.replace("Nous Research", "Anthropic")
|
||||
block["text"] = text
|
||||
|
||||
# 3. Prefix tool names with mcp_ (Claude Code convention)
|
||||
if anthropic_tools:
|
||||
for tool in anthropic_tools:
|
||||
if "name" in tool:
|
||||
tool["name"] = _MCP_TOOL_PREFIX + tool["name"]
|
||||
|
||||
# 4. Prefix tool names in message history (tool_use and tool_result blocks)
|
||||
for msg in anthropic_messages:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict):
|
||||
if block.get("type") == "tool_use" and "name" in block:
|
||||
if not block["name"].startswith(_MCP_TOOL_PREFIX):
|
||||
block["name"] = _MCP_TOOL_PREFIX + block["name"]
|
||||
elif block.get("type") == "tool_result" and "tool_use_id" in block:
|
||||
pass # tool_result uses ID, not name
|
||||
|
||||
kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": anthropic_messages,
|
||||
@@ -739,7 +1081,8 @@ def build_anthropic_kwargs(
|
||||
elif tool_choice == "required":
|
||||
kwargs["tool_choice"] = {"type": "any"}
|
||||
elif tool_choice == "none":
|
||||
pass # Don't send tool_choice — Anthropic will use tools if needed
|
||||
# Anthropic has no tool_choice "none" — omit tools entirely to prevent use
|
||||
kwargs.pop("tools", None)
|
||||
elif isinstance(tool_choice, str):
|
||||
# Specific tool name
|
||||
kwargs["tool_choice"] = {"type": "tool", "name": tool_choice}
|
||||
@@ -768,11 +1111,15 @@ def build_anthropic_kwargs(
|
||||
|
||||
def normalize_anthropic_response(
|
||||
response,
|
||||
strip_tool_prefix: bool = False,
|
||||
) -> Tuple[SimpleNamespace, str]:
|
||||
"""Normalize Anthropic response to match the shape expected by AIAgent.
|
||||
|
||||
Returns (assistant_message, finish_reason) where assistant_message has
|
||||
.content, .tool_calls, and .reasoning attributes.
|
||||
|
||||
When *strip_tool_prefix* is True, removes the ``mcp_`` prefix that was
|
||||
added to tool names for OAuth Claude Code compatibility.
|
||||
"""
|
||||
text_parts = []
|
||||
reasoning_parts = []
|
||||
@@ -784,12 +1131,15 @@ def normalize_anthropic_response(
|
||||
elif block.type == "thinking":
|
||||
reasoning_parts.append(block.thinking)
|
||||
elif block.type == "tool_use":
|
||||
name = block.name
|
||||
if strip_tool_prefix and name.startswith(_MCP_TOOL_PREFIX):
|
||||
name = name[len(_MCP_TOOL_PREFIX):]
|
||||
tool_calls.append(
|
||||
SimpleNamespace(
|
||||
id=block.id,
|
||||
type="function",
|
||||
function=SimpleNamespace(
|
||||
name=block.name,
|
||||
name=name,
|
||||
arguments=json.dumps(block.input),
|
||||
),
|
||||
)
|
||||
|
||||
+189
-60
@@ -39,6 +39,8 @@ custom OpenAI-compatible endpoint without touching the main model settings.
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
@@ -54,9 +56,13 @@ logger = logging.getLogger(__name__)
|
||||
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
|
||||
"zai": "glm-4.5-flash",
|
||||
"kimi-coding": "kimi-k2-turbo-preview",
|
||||
"minimax": "MiniMax-M2.5-highspeed",
|
||||
"minimax-cn": "MiniMax-M2.5-highspeed",
|
||||
"minimax": "MiniMax-M2.7-highspeed",
|
||||
"minimax-cn": "MiniMax-M2.7-highspeed",
|
||||
"anthropic": "claude-haiku-4-5-20251001",
|
||||
"ai-gateway": "google/gemini-3-flash",
|
||||
"opencode-zen": "gemini-3-flash",
|
||||
"opencode-go": "glm-5",
|
||||
"kilocode": "google/gemini-3-flash-preview",
|
||||
}
|
||||
|
||||
# OpenRouter app attribution headers
|
||||
@@ -320,9 +326,10 @@ class AsyncCodexAuxiliaryClient:
|
||||
class _AnthropicCompletionsAdapter:
|
||||
"""OpenAI-client-compatible adapter for Anthropic Messages API."""
|
||||
|
||||
def __init__(self, real_client: Any, model: str):
|
||||
def __init__(self, real_client: Any, model: str, is_oauth: bool = False):
|
||||
self._client = real_client
|
||||
self._model = model
|
||||
self._is_oauth = is_oauth
|
||||
|
||||
def create(self, **kwargs) -> Any:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs, normalize_anthropic_response
|
||||
@@ -351,6 +358,7 @@ class _AnthropicCompletionsAdapter:
|
||||
max_tokens=max_tokens,
|
||||
reasoning_config=None,
|
||||
tool_choice=normalized_tool_choice,
|
||||
is_oauth=self._is_oauth,
|
||||
)
|
||||
if temperature is not None:
|
||||
anthropic_kwargs["temperature"] = temperature
|
||||
@@ -389,9 +397,9 @@ class _AnthropicChatShim:
|
||||
class AnthropicAuxiliaryClient:
|
||||
"""OpenAI-client-compatible wrapper over a native Anthropic client."""
|
||||
|
||||
def __init__(self, real_client: Any, model: str, api_key: str, base_url: str):
|
||||
def __init__(self, real_client: Any, model: str, api_key: str, base_url: str, is_oauth: bool = False):
|
||||
self._real_client = real_client
|
||||
adapter = _AnthropicCompletionsAdapter(real_client, model)
|
||||
adapter = _AnthropicCompletionsAdapter(real_client, model, is_oauth=is_oauth)
|
||||
self.chat = _AnthropicChatShim(adapter)
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
@@ -458,15 +466,30 @@ def _nous_base_url() -> str:
|
||||
|
||||
|
||||
def _read_codex_access_token() -> Optional[str]:
|
||||
"""Read a valid Codex OAuth access token from Hermes auth store (~/.hermes/auth.json)."""
|
||||
"""Read a valid, non-expired Codex OAuth access token from Hermes auth store."""
|
||||
try:
|
||||
from hermes_cli.auth import _read_codex_tokens
|
||||
data = _read_codex_tokens()
|
||||
tokens = data.get("tokens", {})
|
||||
access_token = tokens.get("access_token")
|
||||
if isinstance(access_token, str) and access_token.strip():
|
||||
return access_token.strip()
|
||||
return None
|
||||
if not isinstance(access_token, str) or not access_token.strip():
|
||||
return None
|
||||
|
||||
# Check JWT expiry — expired tokens block the auto chain and
|
||||
# prevent fallback to working providers (e.g. Anthropic).
|
||||
try:
|
||||
import base64
|
||||
payload = access_token.split(".")[1]
|
||||
payload += "=" * (-len(payload) % 4)
|
||||
claims = json.loads(base64.urlsafe_b64decode(payload))
|
||||
exp = claims.get("exp", 0)
|
||||
if exp and time.time() > exp:
|
||||
logger.debug("Codex access token expired (exp=%s), skipping", exp)
|
||||
return None
|
||||
except Exception:
|
||||
pass # Non-JWT token or decode error — use as-is
|
||||
|
||||
return access_token.strip()
|
||||
except Exception as exc:
|
||||
logger.debug("Could not read Codex auth for auxiliary client: %s", exc)
|
||||
return None
|
||||
@@ -475,11 +498,11 @@ def _read_codex_access_token() -> Optional[str]:
|
||||
def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Try each API-key provider in PROVIDER_REGISTRY order.
|
||||
|
||||
Returns (client, model) for the first provider whose env var is set,
|
||||
or (None, None) if none are configured.
|
||||
Returns (client, model) for the first provider with usable runtime
|
||||
credentials, or (None, None) if none are configured.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY, resolve_api_key_provider_credentials
|
||||
except ImportError:
|
||||
logger.debug("Could not import PROVIDER_REGISTRY for API-key fallback")
|
||||
return None, None
|
||||
@@ -487,34 +510,24 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
for provider_id, pconfig in PROVIDER_REGISTRY.items():
|
||||
if pconfig.auth_type != "api_key":
|
||||
continue
|
||||
# Check if any of the provider's env vars are set
|
||||
api_key = ""
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
val = os.getenv(env_var, "").strip()
|
||||
if val:
|
||||
api_key = val
|
||||
break
|
||||
if not api_key:
|
||||
continue
|
||||
if provider_id == "anthropic":
|
||||
return _try_anthropic()
|
||||
|
||||
# Resolve base URL (with optional env-var override)
|
||||
# Kimi Code keys (sk-kimi-) need api.kimi.com/coding/v1
|
||||
env_url = ""
|
||||
if pconfig.base_url_env_var:
|
||||
env_url = os.getenv(pconfig.base_url_env_var, "").strip()
|
||||
if env_url:
|
||||
base_url = env_url.rstrip("/")
|
||||
elif provider_id == "kimi-coding" and api_key.startswith("sk-kimi-"):
|
||||
base_url = "https://api.kimi.com/coding/v1"
|
||||
else:
|
||||
base_url = pconfig.inference_base_url
|
||||
creds = resolve_api_key_provider_credentials(provider_id)
|
||||
api_key = str(creds.get("api_key", "")).strip()
|
||||
if not api_key:
|
||||
continue
|
||||
|
||||
base_url = str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
|
||||
logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
|
||||
extra = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
|
||||
elif "api.githubcopilot.com" in base_url.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
extra["default_headers"] = copilot_default_headers()
|
||||
return OpenAI(api_key=api_key, base_url=base_url, **extra), model
|
||||
|
||||
return None, None
|
||||
@@ -659,10 +672,29 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
|
||||
if not token:
|
||||
return None, None
|
||||
|
||||
# Allow base URL override from config.yaml model.base_url, but only
|
||||
# when the configured provider is anthropic — otherwise a non-Anthropic
|
||||
# base_url (e.g. Codex endpoint) would leak into Anthropic requests.
|
||||
base_url = _ANTHROPIC_DEFAULT_BASE_URL
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config()
|
||||
model_cfg = cfg.get("model")
|
||||
if isinstance(model_cfg, dict):
|
||||
cfg_provider = str(model_cfg.get("provider") or "").strip().lower()
|
||||
if cfg_provider == "anthropic":
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
if cfg_base_url:
|
||||
base_url = cfg_base_url
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from agent.anthropic_adapter import _is_oauth_token
|
||||
is_oauth = _is_oauth_token(token)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get("anthropic", "claude-haiku-4-5-20251001")
|
||||
logger.debug("Auxiliary client: Anthropic native (%s)", model)
|
||||
real_client = build_anthropic_client(token, _ANTHROPIC_DEFAULT_BASE_URL)
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, _ANTHROPIC_DEFAULT_BASE_URL), model
|
||||
logger.debug("Auxiliary client: Anthropic native (%s) at %s (oauth=%s)", model, base_url, is_oauth)
|
||||
real_client = build_anthropic_client(token, base_url)
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model
|
||||
|
||||
|
||||
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
@@ -701,6 +733,8 @@ def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[st
|
||||
|
||||
def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None."""
|
||||
global auxiliary_is_nous
|
||||
auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint,
|
||||
_try_codex, _resolve_api_key_provider):
|
||||
client, model = try_fn()
|
||||
@@ -737,6 +771,10 @@ def _to_async_client(sync_client, model: str):
|
||||
base_lower = str(sync_client.base_url).lower()
|
||||
if "openrouter" in base_lower:
|
||||
async_kwargs["default_headers"] = dict(_OR_HEADERS)
|
||||
elif "api.githubcopilot.com" in base_lower:
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
async_kwargs["default_headers"] = copilot_default_headers()
|
||||
elif "api.kimi.com" in base_lower:
|
||||
async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
|
||||
return AsyncOpenAI(**async_kwargs), model
|
||||
@@ -878,7 +916,7 @@ def resolve_provider_client(
|
||||
|
||||
# ── API-key providers from PROVIDER_REGISTRY ─────────────────────
|
||||
try:
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY, _resolve_kimi_base_url
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY, resolve_api_key_provider_credentials
|
||||
except ImportError:
|
||||
logger.debug("hermes_cli.auth not available for provider %s", provider)
|
||||
return None, None
|
||||
@@ -897,26 +935,18 @@ def resolve_provider_client(
|
||||
final_model = model or default_model
|
||||
return (_to_async_client(client, final_model) if async_mode else (client, final_model))
|
||||
|
||||
# Find the first configured API key
|
||||
api_key = ""
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
api_key = os.getenv(env_var, "").strip()
|
||||
if api_key:
|
||||
break
|
||||
creds = resolve_api_key_provider_credentials(provider)
|
||||
api_key = str(creds.get("api_key", "")).strip()
|
||||
if not api_key:
|
||||
tried_sources = list(pconfig.api_key_env_vars)
|
||||
if provider == "copilot":
|
||||
tried_sources.append("gh auth token")
|
||||
logger.warning("resolve_provider_client: provider %s has no API "
|
||||
"key configured (tried: %s)",
|
||||
provider, ", ".join(pconfig.api_key_env_vars))
|
||||
provider, ", ".join(tried_sources))
|
||||
return None, None
|
||||
|
||||
# Resolve base URL (env override → provider-specific logic → default)
|
||||
base_url_override = os.getenv(pconfig.base_url_env_var, "").strip() if pconfig.base_url_env_var else ""
|
||||
if provider == "kimi-coding":
|
||||
base_url = _resolve_kimi_base_url(api_key, pconfig.inference_base_url, base_url_override)
|
||||
elif base_url_override:
|
||||
base_url = base_url_override
|
||||
else:
|
||||
base_url = pconfig.inference_base_url
|
||||
base_url = str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url
|
||||
|
||||
default_model = _API_KEY_PROVIDER_AUX_MODELS.get(provider, "")
|
||||
final_model = model or default_model
|
||||
@@ -925,6 +955,10 @@ def resolve_provider_client(
|
||||
headers = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
headers["User-Agent"] = "KimiCLI/1.0"
|
||||
elif "api.githubcopilot.com" in base_url.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
headers.update(copilot_default_headers())
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url=base_url,
|
||||
**({"default_headers": headers} if headers else {}))
|
||||
@@ -1167,6 +1201,54 @@ def auxiliary_max_tokens_param(value: int) -> dict:
|
||||
|
||||
# Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model)
|
||||
_client_cache: Dict[tuple, tuple] = {}
|
||||
_client_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def _force_close_async_httpx(client: Any) -> None:
|
||||
"""Mark the httpx AsyncClient inside an AsyncOpenAI client as closed.
|
||||
|
||||
This prevents ``AsyncHttpxClientWrapper.__del__`` from scheduling
|
||||
``aclose()`` on a (potentially closed) event loop, which causes
|
||||
``RuntimeError: Event loop is closed`` → prompt_toolkit's
|
||||
"Press ENTER to continue..." handler.
|
||||
|
||||
We intentionally do NOT run the full async close path — the
|
||||
connections will be dropped by the OS when the process exits.
|
||||
"""
|
||||
try:
|
||||
from httpx._client import ClientState
|
||||
inner = getattr(client, "_client", None)
|
||||
if inner is not None and not getattr(inner, "is_closed", True):
|
||||
inner._state = ClientState.CLOSED
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def shutdown_cached_clients() -> None:
|
||||
"""Close all cached clients (sync and async) to prevent event-loop errors.
|
||||
|
||||
Call this during CLI shutdown, *before* the event loop is closed, to
|
||||
avoid ``AsyncHttpxClientWrapper.__del__`` raising on a dead loop.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
with _client_cache_lock:
|
||||
for key, entry in list(_client_cache.items()):
|
||||
client = entry[0]
|
||||
if client is None:
|
||||
continue
|
||||
# Mark any async httpx transport as closed first (prevents __del__
|
||||
# from scheduling aclose() on a dead event loop).
|
||||
_force_close_async_httpx(client)
|
||||
# Sync clients: close the httpx connection pool cleanly.
|
||||
# Async clients: skip — we already neutered __del__ above.
|
||||
try:
|
||||
close_fn = getattr(client, "close", None)
|
||||
if close_fn and not inspect.iscoroutinefunction(close_fn):
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
_client_cache.clear()
|
||||
|
||||
|
||||
def _get_cached_client(
|
||||
@@ -1178,9 +1260,22 @@ def _get_cached_client(
|
||||
) -> Tuple[Optional[Any], Optional[str]]:
|
||||
"""Get or create a cached client for the given provider."""
|
||||
cache_key = (provider, async_mode, base_url or "", api_key or "")
|
||||
if cache_key in _client_cache:
|
||||
cached_client, cached_default = _client_cache[cache_key]
|
||||
return cached_client, model or cached_default
|
||||
with _client_cache_lock:
|
||||
if cache_key in _client_cache:
|
||||
cached_client, cached_default, cached_loop = _client_cache[cache_key]
|
||||
if async_mode:
|
||||
# Async clients are bound to the event loop that created them.
|
||||
# A cached async client whose loop has been closed will raise
|
||||
# "Event loop is closed" when httpx tries to clean up its
|
||||
# transport. Discard the stale client and create a fresh one.
|
||||
if cached_loop is not None and cached_loop.is_closed():
|
||||
_force_close_async_httpx(cached_client)
|
||||
del _client_cache[cache_key]
|
||||
else:
|
||||
return cached_client, model or cached_default
|
||||
else:
|
||||
return cached_client, model or cached_default
|
||||
# Build outside the lock
|
||||
client, default_model = resolve_provider_client(
|
||||
provider,
|
||||
model,
|
||||
@@ -1189,7 +1284,20 @@ def _get_cached_client(
|
||||
explicit_api_key=api_key,
|
||||
)
|
||||
if client is not None:
|
||||
_client_cache[cache_key] = (client, default_model)
|
||||
# For async clients, remember which loop they were created on so we
|
||||
# can detect stale entries later.
|
||||
bound_loop = None
|
||||
if async_mode:
|
||||
try:
|
||||
import asyncio as _aio
|
||||
bound_loop = _aio.get_event_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
with _client_cache_lock:
|
||||
if cache_key not in _client_cache:
|
||||
_client_cache[cache_key] = (client, default_model, bound_loop)
|
||||
else:
|
||||
client, default_model, _ = _client_cache[cache_key]
|
||||
return client, model or default_model
|
||||
|
||||
|
||||
@@ -1234,12 +1342,16 @@ def _resolve_task_provider_model(
|
||||
cfg_base_url = str(task_config.get("base_url", "")).strip() or None
|
||||
cfg_api_key = str(task_config.get("api_key", "")).strip() or None
|
||||
|
||||
# Backwards compat: compression section has its own keys
|
||||
if task == "compression" and not cfg_provider:
|
||||
# Backwards compat: compression section has its own keys.
|
||||
# The auxiliary.compression defaults to provider="auto", so treat
|
||||
# both None and "auto" as "not explicitly configured".
|
||||
if task == "compression" and (not cfg_provider or cfg_provider == "auto"):
|
||||
comp = config.get("compression", {}) if isinstance(config, dict) else {}
|
||||
if isinstance(comp, dict):
|
||||
cfg_provider = comp.get("summary_provider", "").strip() or None
|
||||
cfg_model = cfg_model or comp.get("summary_model", "").strip() or None
|
||||
_sbu = comp.get("summary_base_url") or ""
|
||||
cfg_base_url = cfg_base_url or _sbu.strip() or None
|
||||
|
||||
env_model = _get_auxiliary_env_override(task, "MODEL") if task else None
|
||||
resolved_model = model or env_model or cfg_model
|
||||
@@ -1387,8 +1499,18 @@ def call_llm(
|
||||
api_key=resolved_api_key,
|
||||
)
|
||||
if client is None:
|
||||
# Fallback: try openrouter
|
||||
if resolved_provider != "openrouter" and not resolved_base_url:
|
||||
# When the user explicitly chose a non-OpenRouter provider but no
|
||||
# credentials were found, fail fast instead of silently routing
|
||||
# through OpenRouter (which causes confusing 404s).
|
||||
_explicit = (resolved_provider or "").strip().lower()
|
||||
if _explicit and _explicit not in ("auto", "openrouter", "custom"):
|
||||
raise RuntimeError(
|
||||
f"Provider '{_explicit}' is set in config.yaml but no API key "
|
||||
f"was found. Set the {_explicit.upper()}_API_KEY environment "
|
||||
f"variable, or switch to a different provider with `hermes model`."
|
||||
)
|
||||
# For auto/custom, fall back to OpenRouter
|
||||
if not resolved_base_url:
|
||||
logger.warning("Provider %s unavailable, falling back to openrouter",
|
||||
resolved_provider)
|
||||
client, final_model = _get_cached_client(
|
||||
@@ -1470,7 +1592,14 @@ async def async_call_llm(
|
||||
api_key=resolved_api_key,
|
||||
)
|
||||
if client is None:
|
||||
if resolved_provider != "openrouter" and not resolved_base_url:
|
||||
_explicit = (resolved_provider or "").strip().lower()
|
||||
if _explicit and _explicit not in ("auto", "openrouter", "custom"):
|
||||
raise RuntimeError(
|
||||
f"Provider '{_explicit}' is set in config.yaml but no API key "
|
||||
f"was found. Set the {_explicit.upper()}_API_KEY environment "
|
||||
f"variable, or switch to a different provider with `hermes model`."
|
||||
)
|
||||
if not resolved_base_url:
|
||||
logger.warning("Provider %s unavailable, falling back to openrouter",
|
||||
resolved_provider)
|
||||
client, final_model = _get_cached_client(
|
||||
|
||||
+386
-63
@@ -1,8 +1,16 @@
|
||||
"""Automatic context window compression for long conversations.
|
||||
|
||||
Self-contained class with its own OpenAI client for summarization.
|
||||
Uses Gemini Flash (cheap/fast) to summarize middle turns while
|
||||
Uses auxiliary model (cheap/fast) to summarize middle turns while
|
||||
protecting head and tail context.
|
||||
|
||||
Improvements over v1:
|
||||
- Structured summary template (Goal, Progress, Decisions, Files, Next Steps)
|
||||
- Iterative summary updates (preserves info across multiple compactions)
|
||||
- Token-budget tail protection instead of fixed message count
|
||||
- Tool output pruning before LLM summarization (cheap pre-pass)
|
||||
- Scaled summary budget (proportional to compressed content)
|
||||
- Richer tool call/result detail in summarizer input
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -27,12 +35,31 @@ SUMMARY_PREFIX = (
|
||||
)
|
||||
LEGACY_SUMMARY_PREFIX = "[CONTEXT SUMMARY]:"
|
||||
|
||||
# Minimum / maximum tokens for the summary output
|
||||
_MIN_SUMMARY_TOKENS = 2000
|
||||
_MAX_SUMMARY_TOKENS = 8000
|
||||
# Proportion of compressed content to allocate for summary
|
||||
_SUMMARY_RATIO = 0.20
|
||||
|
||||
# Token budget for tail protection (keep most-recent context)
|
||||
_DEFAULT_TAIL_TOKEN_BUDGET = 20_000
|
||||
|
||||
# Placeholder used when pruning old tool results
|
||||
_PRUNED_TOOL_PLACEHOLDER = "[Old tool output cleared to save context space]"
|
||||
|
||||
# Chars per token rough estimate
|
||||
_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
"""Compresses conversation context when approaching the model's context limit.
|
||||
|
||||
Algorithm: protect first N + last N turns, summarize everything in between.
|
||||
Token tracking uses actual counts from API responses for accuracy.
|
||||
Algorithm:
|
||||
1. Prune old tool results (cheap, no LLM call)
|
||||
2. Protect head messages (system prompt + first exchange)
|
||||
3. Protect tail messages by token budget (most recent ~20K tokens)
|
||||
4. Summarize middle turns with structured LLM prompt
|
||||
5. On subsequent compactions, iteratively update the previous summary
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -45,18 +72,35 @@ class ContextCompressor:
|
||||
quiet_mode: bool = False,
|
||||
summary_model_override: str = None,
|
||||
base_url: str = "",
|
||||
api_key: str = "",
|
||||
config_context_length: int | None = None,
|
||||
provider: str = "",
|
||||
):
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.provider = provider
|
||||
self.threshold_percent = threshold_percent
|
||||
self.protect_first_n = protect_first_n
|
||||
self.protect_last_n = protect_last_n
|
||||
self.summary_target_tokens = summary_target_tokens
|
||||
self.quiet_mode = quiet_mode
|
||||
|
||||
self.context_length = get_model_context_length(model, base_url=base_url)
|
||||
self.context_length = get_model_context_length(
|
||||
model, base_url=base_url, api_key=api_key,
|
||||
config_context_length=config_context_length,
|
||||
provider=provider,
|
||||
)
|
||||
self.threshold_tokens = int(self.context_length * threshold_percent)
|
||||
self.compression_count = 0
|
||||
|
||||
if not quiet_mode:
|
||||
logger.info(
|
||||
"Context compressor initialized: model=%s context_length=%d "
|
||||
"threshold=%d (%.0f%%) provider=%s base_url=%s",
|
||||
model, self.context_length, self.threshold_tokens,
|
||||
threshold_percent * 100, provider or "none", base_url or "none",
|
||||
)
|
||||
self._context_probed = False # True after a step-down from context error
|
||||
|
||||
self.last_prompt_tokens = 0
|
||||
@@ -65,6 +109,9 @@ class ContextCompressor:
|
||||
|
||||
self.summary_model = summary_model_override or ""
|
||||
|
||||
# Stores the previous compaction summary for iterative updates
|
||||
self._previous_summary: Optional[str] = None
|
||||
|
||||
def update_from_response(self, usage: Dict[str, Any]):
|
||||
"""Update tracked token usage from API response."""
|
||||
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
@@ -91,53 +138,204 @@ class ContextCompressor:
|
||||
"compression_count": self.compression_count,
|
||||
}
|
||||
|
||||
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
|
||||
"""Generate a concise summary of conversation turns.
|
||||
# ------------------------------------------------------------------
|
||||
# Tool output pruning (cheap pre-pass, no LLM call)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
Tries the auxiliary model first, then falls back to the user's main
|
||||
model. Returns None if all attempts fail — the caller should drop
|
||||
def _prune_old_tool_results(
|
||||
self, messages: List[Dict[str, Any]], protect_tail_count: int,
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""Replace old tool result contents with a short placeholder.
|
||||
|
||||
Walks backward from the end, protecting the most recent
|
||||
``protect_tail_count`` messages. Older tool results get their
|
||||
content replaced with a placeholder string.
|
||||
|
||||
Returns (pruned_messages, pruned_count).
|
||||
"""
|
||||
if not messages:
|
||||
return messages, 0
|
||||
|
||||
result = [m.copy() for m in messages]
|
||||
pruned = 0
|
||||
prune_boundary = len(result) - protect_tail_count
|
||||
|
||||
for i in range(prune_boundary):
|
||||
msg = result[i]
|
||||
if msg.get("role") != "tool":
|
||||
continue
|
||||
content = msg.get("content", "")
|
||||
if not content or content == _PRUNED_TOOL_PLACEHOLDER:
|
||||
continue
|
||||
# Only prune if the content is substantial (>200 chars)
|
||||
if len(content) > 200:
|
||||
result[i] = {**msg, "content": _PRUNED_TOOL_PLACEHOLDER}
|
||||
pruned += 1
|
||||
|
||||
return result, pruned
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Summarization
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compute_summary_budget(self, turns_to_summarize: List[Dict[str, Any]]) -> int:
|
||||
"""Scale summary token budget with the amount of content being compressed."""
|
||||
content_tokens = estimate_messages_tokens_rough(turns_to_summarize)
|
||||
budget = int(content_tokens * _SUMMARY_RATIO)
|
||||
return max(_MIN_SUMMARY_TOKENS, min(budget, _MAX_SUMMARY_TOKENS))
|
||||
|
||||
def _serialize_for_summary(self, turns: List[Dict[str, Any]]) -> str:
|
||||
"""Serialize conversation turns into labeled text for the summarizer.
|
||||
|
||||
Includes tool call arguments and result content (up to 3000 chars
|
||||
per message) so the summarizer can preserve specific details like
|
||||
file paths, commands, and outputs.
|
||||
"""
|
||||
parts = []
|
||||
for msg in turns:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content") or ""
|
||||
|
||||
# Tool results: keep more content than before (3000 chars)
|
||||
if role == "tool":
|
||||
tool_id = msg.get("tool_call_id", "")
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
parts.append(f"[TOOL RESULT {tool_id}]: {content}")
|
||||
continue
|
||||
|
||||
# Assistant messages: include tool call names AND arguments
|
||||
if role == "assistant":
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
tool_calls = msg.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
tc_parts = []
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
fn = tc.get("function", {})
|
||||
name = fn.get("name", "?")
|
||||
args = fn.get("arguments", "")
|
||||
# Truncate long arguments but keep enough for context
|
||||
if len(args) > 500:
|
||||
args = args[:400] + "..."
|
||||
tc_parts.append(f" {name}({args})")
|
||||
else:
|
||||
fn = getattr(tc, "function", None)
|
||||
name = getattr(fn, "name", "?") if fn else "?"
|
||||
tc_parts.append(f" {name}(...)")
|
||||
content += "\n[Tool calls:\n" + "\n".join(tc_parts) + "\n]"
|
||||
parts.append(f"[ASSISTANT]: {content}")
|
||||
continue
|
||||
|
||||
# User and other roles
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
parts.append(f"[{role.upper()}]: {content}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
|
||||
"""Generate a structured summary of conversation turns.
|
||||
|
||||
Uses a structured template (Goal, Progress, Decisions, Files, Next Steps)
|
||||
inspired by Pi-mono and OpenCode. When a previous summary exists,
|
||||
generates an iterative update instead of summarizing from scratch.
|
||||
|
||||
Returns None if all attempts fail — the caller should drop
|
||||
the middle turns without a summary rather than inject a useless
|
||||
placeholder.
|
||||
"""
|
||||
parts = []
|
||||
for msg in turns_to_summarize:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content") or ""
|
||||
if len(content) > 2000:
|
||||
content = content[:1000] + "\n...[truncated]...\n" + content[-500:]
|
||||
tool_calls = msg.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
tool_names = [tc.get("function", {}).get("name", "?") for tc in tool_calls if isinstance(tc, dict)]
|
||||
content += f"\n[Tool calls: {', '.join(tool_names)}]"
|
||||
parts.append(f"[{role.upper()}]: {content}")
|
||||
summary_budget = self._compute_summary_budget(turns_to_summarize)
|
||||
content_to_summarize = self._serialize_for_summary(turns_to_summarize)
|
||||
|
||||
content_to_summarize = "\n\n".join(parts)
|
||||
prompt = f"""Create a concise handoff summary for a later assistant that will continue this conversation after earlier turns are compacted.
|
||||
if self._previous_summary:
|
||||
# Iterative update: preserve existing info, add new progress
|
||||
prompt = f"""You are updating a context compaction summary. A previous compaction produced the summary below. New conversation turns have occurred since then and need to be incorporated.
|
||||
|
||||
Describe:
|
||||
1. What actions were taken (tool calls, searches, file operations)
|
||||
2. Key information or results obtained
|
||||
3. Important decisions, constraints, or user preferences
|
||||
4. Relevant data, file names, outputs, or next steps needed to continue
|
||||
PREVIOUS SUMMARY:
|
||||
{self._previous_summary}
|
||||
|
||||
Keep it factual, concise, and focused on helping the next assistant resume without repeating work. Target ~{self.summary_target_tokens} tokens.
|
||||
NEW TURNS TO INCORPORATE:
|
||||
{content_to_summarize}
|
||||
|
||||
Update the summary using this exact structure. PRESERVE all existing information that is still relevant. ADD new progress. Move items from "In Progress" to "Done" when completed. Remove information only if it is clearly obsolete.
|
||||
|
||||
## Goal
|
||||
[What the user is trying to accomplish — preserve from previous summary, update if goal evolved]
|
||||
|
||||
## Constraints & Preferences
|
||||
[User preferences, coding style, constraints, important decisions — accumulate across compactions]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
[Completed work — include specific file paths, commands run, results obtained]
|
||||
### In Progress
|
||||
[Work currently underway]
|
||||
### Blocked
|
||||
[Any blockers or issues encountered]
|
||||
|
||||
## Key Decisions
|
||||
[Important technical decisions and why they were made]
|
||||
|
||||
## Relevant Files
|
||||
[Files read, modified, or created — with brief note on each. Accumulate across compactions.]
|
||||
|
||||
## Next Steps
|
||||
[What needs to happen next to continue the work]
|
||||
|
||||
## Critical Context
|
||||
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
else:
|
||||
# First compaction: summarize from scratch
|
||||
prompt = f"""Create a structured handoff summary for a later assistant that will continue this conversation after earlier turns are compacted.
|
||||
|
||||
---
|
||||
TURNS TO SUMMARIZE:
|
||||
{content_to_summarize}
|
||||
---
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix; the system will add the handoff wrapper."""
|
||||
Use this exact structure:
|
||||
|
||||
## Goal
|
||||
[What the user is trying to accomplish]
|
||||
|
||||
## Constraints & Preferences
|
||||
[User preferences, coding style, constraints, important decisions]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
[Completed work — include specific file paths, commands run, results obtained]
|
||||
### In Progress
|
||||
[Work currently underway]
|
||||
### Blocked
|
||||
[Any blockers or issues encountered]
|
||||
|
||||
## Key Decisions
|
||||
[Important technical decisions and why they were made]
|
||||
|
||||
## Relevant Files
|
||||
[Files read, modified, or created — with brief note on each]
|
||||
|
||||
## Next Steps
|
||||
[What needs to happen next to continue the work]
|
||||
|
||||
## Critical Context
|
||||
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions. The goal is to prevent the next assistant from repeating work or losing important details.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
|
||||
# Use the centralized LLM router — handles provider resolution,
|
||||
# auth, and fallback internally.
|
||||
try:
|
||||
call_kwargs = {
|
||||
"task": "compression",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.3,
|
||||
"max_tokens": self.summary_target_tokens * 2,
|
||||
"timeout": 30.0,
|
||||
"max_tokens": summary_budget * 2,
|
||||
"timeout": 45.0,
|
||||
}
|
||||
if self.summary_model:
|
||||
call_kwargs["model"] = self.summary_model
|
||||
@@ -147,6 +345,8 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||
if not isinstance(content, str):
|
||||
content = str(content) if content else ""
|
||||
summary = content.strip()
|
||||
# Store for iterative updates on next compaction
|
||||
self._previous_summary = summary
|
||||
return self._with_summary_prefix(summary)
|
||||
except RuntimeError:
|
||||
logging.warning("Context compression: no provider available for "
|
||||
@@ -251,56 +451,149 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||
"""Pull a compress-end boundary backward to avoid splitting a
|
||||
tool_call / result group.
|
||||
|
||||
If the message just before ``idx`` is an assistant message with
|
||||
tool_calls, those tool results will start at ``idx`` and would be
|
||||
separated from their parent. Move backwards to include the whole
|
||||
group in the summarised region.
|
||||
If the boundary falls in the middle of a tool-result group (i.e.
|
||||
there are consecutive tool messages before ``idx``), walk backward
|
||||
past all of them to find the parent assistant message. If found,
|
||||
move the boundary before the assistant so the entire
|
||||
assistant + tool_results group is included in the summarised region
|
||||
rather than being split (which causes silent data loss when
|
||||
``_sanitize_tool_pairs`` removes the orphaned tail results).
|
||||
"""
|
||||
if idx <= 0 or idx >= len(messages):
|
||||
return idx
|
||||
prev = messages[idx - 1]
|
||||
if prev.get("role") == "assistant" and prev.get("tool_calls"):
|
||||
# The results for this assistant turn sit at idx..idx+k.
|
||||
# Include the assistant message in the summarised region too.
|
||||
idx -= 1
|
||||
# Walk backward past consecutive tool results
|
||||
check = idx - 1
|
||||
while check >= 0 and messages[check].get("role") == "tool":
|
||||
check -= 1
|
||||
# If we landed on the parent assistant with tool_calls, pull the
|
||||
# boundary before it so the whole group gets summarised together.
|
||||
if check >= 0 and messages[check].get("role") == "assistant" and messages[check].get("tool_calls"):
|
||||
idx = check
|
||||
return idx
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tail protection by token budget
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _find_tail_cut_by_tokens(
|
||||
self, messages: List[Dict[str, Any]], head_end: int,
|
||||
token_budget: int = _DEFAULT_TAIL_TOKEN_BUDGET,
|
||||
) -> int:
|
||||
"""Walk backward from the end of messages, accumulating tokens until
|
||||
the budget is reached. Returns the index where the tail starts.
|
||||
|
||||
Never cuts inside a tool_call/result group. Falls back to the old
|
||||
``protect_last_n`` if the budget would protect fewer messages.
|
||||
"""
|
||||
n = len(messages)
|
||||
min_tail = self.protect_last_n
|
||||
accumulated = 0
|
||||
cut_idx = n # start from beyond the end
|
||||
|
||||
for i in range(n - 1, head_end - 1, -1):
|
||||
msg = messages[i]
|
||||
content = msg.get("content") or ""
|
||||
msg_tokens = len(content) // _CHARS_PER_TOKEN + 10 # +10 for role/metadata
|
||||
# Include tool call arguments in estimate
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict):
|
||||
args = tc.get("function", {}).get("arguments", "")
|
||||
msg_tokens += len(args) // _CHARS_PER_TOKEN
|
||||
if accumulated + msg_tokens > token_budget and (n - i) >= min_tail:
|
||||
break
|
||||
accumulated += msg_tokens
|
||||
cut_idx = i
|
||||
|
||||
# Ensure we protect at least protect_last_n messages
|
||||
fallback_cut = n - min_tail
|
||||
if cut_idx > fallback_cut:
|
||||
cut_idx = fallback_cut
|
||||
|
||||
# If the token budget would protect everything (small conversations),
|
||||
# fall back to the fixed protect_last_n approach so compression can
|
||||
# still remove middle turns.
|
||||
if cut_idx <= head_end:
|
||||
cut_idx = fallback_cut
|
||||
|
||||
# Align to avoid splitting tool groups
|
||||
cut_idx = self._align_boundary_backward(messages, cut_idx)
|
||||
|
||||
return max(cut_idx, head_end + 1)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Main compression entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
|
||||
"""Compress conversation messages by summarizing middle turns.
|
||||
|
||||
Keeps first N + last N turns, summarizes everything in between.
|
||||
Algorithm:
|
||||
1. Prune old tool results (cheap pre-pass, no LLM call)
|
||||
2. Protect head messages (system prompt + first exchange)
|
||||
3. Find tail boundary by token budget (~20K tokens of recent context)
|
||||
4. Summarize middle turns with structured LLM prompt
|
||||
5. On re-compression, iteratively update the previous summary
|
||||
|
||||
After compression, orphaned tool_call / tool_result pairs are cleaned
|
||||
up so the API never receives mismatched IDs.
|
||||
"""
|
||||
n_messages = len(messages)
|
||||
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
|
||||
if not self.quiet_mode:
|
||||
print(f"⚠️ Cannot compress: only {n_messages} messages (need > {self.protect_first_n + self.protect_last_n + 1})")
|
||||
logger.warning(
|
||||
"Cannot compress: only %d messages (need > %d)",
|
||||
n_messages,
|
||||
self.protect_first_n + self.protect_last_n + 1,
|
||||
)
|
||||
return messages
|
||||
|
||||
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
|
||||
# Phase 1: Prune old tool results (cheap, no LLM call)
|
||||
messages, pruned_count = self._prune_old_tool_results(
|
||||
messages, protect_tail_count=self.protect_last_n * 3,
|
||||
)
|
||||
if pruned_count and not self.quiet_mode:
|
||||
logger.info("Pre-compression: pruned %d old tool result(s)", pruned_count)
|
||||
|
||||
# Phase 2: Determine boundaries
|
||||
compress_start = self.protect_first_n
|
||||
compress_end = n_messages - self.protect_last_n
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
|
||||
# Adjust boundaries to avoid splitting tool_call/result groups.
|
||||
compress_start = self._align_boundary_forward(messages, compress_start)
|
||||
compress_end = self._align_boundary_backward(messages, compress_end)
|
||||
|
||||
# Use token-budget tail protection instead of fixed message count
|
||||
compress_end = self._find_tail_cut_by_tokens(messages, compress_start)
|
||||
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
|
||||
turns_to_summarize = messages[compress_start:compress_end]
|
||||
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)")
|
||||
print(f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent*100:.0f}% = {self.threshold_tokens:,})")
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f" 🗜️ Summarizing turns {compress_start+1}-{compress_end} ({len(turns_to_summarize)} turns)")
|
||||
logger.info(
|
||||
"Context compression triggered (%d tokens >= %d threshold)",
|
||||
display_tokens,
|
||||
self.threshold_tokens,
|
||||
)
|
||||
logger.info(
|
||||
"Model context limit: %d tokens (%.0f%% = %d)",
|
||||
self.context_length,
|
||||
self.threshold_percent * 100,
|
||||
self.threshold_tokens,
|
||||
)
|
||||
tail_msgs = n_messages - compress_end
|
||||
logger.info(
|
||||
"Summarizing turns %d-%d (%d turns), protecting %d head + %d tail messages",
|
||||
compress_start + 1,
|
||||
compress_end,
|
||||
len(turns_to_summarize),
|
||||
compress_start,
|
||||
tail_msgs,
|
||||
)
|
||||
|
||||
# Phase 3: Generate structured summary
|
||||
summary = self._generate_summary(turns_to_summarize)
|
||||
|
||||
# Phase 4: Assemble compressed message list
|
||||
compressed = []
|
||||
for i in range(compress_start):
|
||||
msg = messages[i].copy()
|
||||
@@ -311,16 +604,41 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||
)
|
||||
compressed.append(msg)
|
||||
|
||||
_merge_summary_into_tail = False
|
||||
if summary:
|
||||
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
|
||||
summary_role = "user" if last_head_role in ("assistant", "tool") else "assistant"
|
||||
compressed.append({"role": summary_role, "content": summary})
|
||||
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
|
||||
# Pick a role that avoids consecutive same-role with both neighbors.
|
||||
# Priority: avoid colliding with head (already committed), then tail.
|
||||
if last_head_role in ("assistant", "tool"):
|
||||
summary_role = "user"
|
||||
else:
|
||||
summary_role = "assistant"
|
||||
# If the chosen role collides with the tail AND flipping wouldn't
|
||||
# collide with the head, flip it.
|
||||
if summary_role == first_tail_role:
|
||||
flipped = "assistant" if summary_role == "user" else "user"
|
||||
if flipped != last_head_role:
|
||||
summary_role = flipped
|
||||
else:
|
||||
# Both roles would create consecutive same-role messages
|
||||
# (e.g. head=assistant, tail=user — neither role works).
|
||||
# Merge the summary into the first tail message instead
|
||||
# of inserting a standalone message that breaks alternation.
|
||||
_merge_summary_into_tail = True
|
||||
if not _merge_summary_into_tail:
|
||||
compressed.append({"role": summary_role, "content": summary})
|
||||
else:
|
||||
if not self.quiet_mode:
|
||||
print(" ⚠️ No summary model available — middle turns dropped without summary")
|
||||
logger.warning("No summary model available — middle turns dropped without summary")
|
||||
|
||||
for i in range(compress_end, n_messages):
|
||||
compressed.append(messages[i].copy())
|
||||
msg = messages[i].copy()
|
||||
if _merge_summary_into_tail and i == compress_end:
|
||||
original = msg.get("content") or ""
|
||||
msg["content"] = summary + "\n\n" + original
|
||||
_merge_summary_into_tail = False
|
||||
compressed.append(msg)
|
||||
|
||||
self.compression_count += 1
|
||||
|
||||
@@ -329,7 +647,12 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||
if not self.quiet_mode:
|
||||
new_estimate = estimate_messages_tokens_rough(compressed)
|
||||
saved_estimate = display_tokens - new_estimate
|
||||
print(f" ✅ Compressed: {n_messages} → {len(compressed)} messages (~{saved_estimate:,} tokens saved)")
|
||||
print(f" 💡 Compression #{self.compression_count} complete")
|
||||
logger.info(
|
||||
"Compressed: %d -> %d messages (~%d tokens saved)",
|
||||
n_messages,
|
||||
len(compressed),
|
||||
saved_estimate,
|
||||
)
|
||||
logger.info("Compression #%d complete", self.compression_count)
|
||||
|
||||
return compressed
|
||||
|
||||
@@ -0,0 +1,485 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from agent.model_metadata import estimate_tokens_rough
|
||||
|
||||
REFERENCE_PATTERN = re.compile(
|
||||
r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))"
|
||||
)
|
||||
TRAILING_PUNCTUATION = ",.;!?"
|
||||
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube")
|
||||
_SENSITIVE_HERMES_DIRS = (Path("skills") / ".hub",)
|
||||
_SENSITIVE_HOME_FILES = (
|
||||
Path(".ssh") / "authorized_keys",
|
||||
Path(".ssh") / "id_rsa",
|
||||
Path(".ssh") / "id_ed25519",
|
||||
Path(".ssh") / "config",
|
||||
Path(".bashrc"),
|
||||
Path(".zshrc"),
|
||||
Path(".profile"),
|
||||
Path(".bash_profile"),
|
||||
Path(".zprofile"),
|
||||
Path(".netrc"),
|
||||
Path(".pgpass"),
|
||||
Path(".npmrc"),
|
||||
Path(".pypirc"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextReference:
|
||||
raw: str
|
||||
kind: str
|
||||
target: str
|
||||
start: int
|
||||
end: int
|
||||
line_start: int | None = None
|
||||
line_end: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextReferenceResult:
|
||||
message: str
|
||||
original_message: str
|
||||
references: list[ContextReference] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
injected_tokens: int = 0
|
||||
expanded: bool = False
|
||||
blocked: bool = False
|
||||
|
||||
|
||||
def parse_context_references(message: str) -> list[ContextReference]:
|
||||
refs: list[ContextReference] = []
|
||||
if not message:
|
||||
return refs
|
||||
|
||||
for match in REFERENCE_PATTERN.finditer(message):
|
||||
simple = match.group("simple")
|
||||
if simple:
|
||||
refs.append(
|
||||
ContextReference(
|
||||
raw=match.group(0),
|
||||
kind=simple,
|
||||
target="",
|
||||
start=match.start(),
|
||||
end=match.end(),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
kind = match.group("kind")
|
||||
value = _strip_trailing_punctuation(match.group("value") or "")
|
||||
line_start = None
|
||||
line_end = None
|
||||
target = value
|
||||
|
||||
if kind == "file":
|
||||
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
|
||||
if range_match:
|
||||
target = range_match.group("path")
|
||||
line_start = int(range_match.group("start"))
|
||||
line_end = int(range_match.group("end") or range_match.group("start"))
|
||||
|
||||
refs.append(
|
||||
ContextReference(
|
||||
raw=match.group(0),
|
||||
kind=kind,
|
||||
target=target,
|
||||
start=match.start(),
|
||||
end=match.end(),
|
||||
line_start=line_start,
|
||||
line_end=line_end,
|
||||
)
|
||||
)
|
||||
|
||||
return refs
|
||||
|
||||
|
||||
def preprocess_context_references(
|
||||
message: str,
|
||||
*,
|
||||
cwd: str | Path,
|
||||
context_length: int,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
allowed_root: str | Path | None = None,
|
||||
) -> ContextReferenceResult:
|
||||
coro = preprocess_context_references_async(
|
||||
message,
|
||||
cwd=cwd,
|
||||
context_length=context_length,
|
||||
url_fetcher=url_fetcher,
|
||||
allowed_root=allowed_root,
|
||||
)
|
||||
# Safe for both CLI (no loop) and gateway (loop already running).
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
async def preprocess_context_references_async(
|
||||
message: str,
|
||||
*,
|
||||
cwd: str | Path,
|
||||
context_length: int,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
allowed_root: str | Path | None = None,
|
||||
) -> ContextReferenceResult:
|
||||
refs = parse_context_references(message)
|
||||
if not refs:
|
||||
return ContextReferenceResult(message=message, original_message=message)
|
||||
|
||||
cwd_path = Path(cwd).expanduser().resolve()
|
||||
# Default to the current working directory so @ references cannot escape
|
||||
# the active workspace unless a caller explicitly widens the root.
|
||||
allowed_root_path = (
|
||||
Path(allowed_root).expanduser().resolve() if allowed_root is not None else cwd_path
|
||||
)
|
||||
warnings: list[str] = []
|
||||
blocks: list[str] = []
|
||||
injected_tokens = 0
|
||||
|
||||
for ref in refs:
|
||||
warning, block = await _expand_reference(
|
||||
ref,
|
||||
cwd_path,
|
||||
url_fetcher=url_fetcher,
|
||||
allowed_root=allowed_root_path,
|
||||
)
|
||||
if warning:
|
||||
warnings.append(warning)
|
||||
if block:
|
||||
blocks.append(block)
|
||||
injected_tokens += estimate_tokens_rough(block)
|
||||
|
||||
hard_limit = max(1, int(context_length * 0.50))
|
||||
soft_limit = max(1, int(context_length * 0.25))
|
||||
if injected_tokens > hard_limit:
|
||||
warnings.append(
|
||||
f"@ context injection refused: {injected_tokens} tokens exceeds the 50% hard limit ({hard_limit})."
|
||||
)
|
||||
return ContextReferenceResult(
|
||||
message=message,
|
||||
original_message=message,
|
||||
references=refs,
|
||||
warnings=warnings,
|
||||
injected_tokens=injected_tokens,
|
||||
expanded=False,
|
||||
blocked=True,
|
||||
)
|
||||
|
||||
if injected_tokens > soft_limit:
|
||||
warnings.append(
|
||||
f"@ context injection warning: {injected_tokens} tokens exceeds the 25% soft limit ({soft_limit})."
|
||||
)
|
||||
|
||||
stripped = _remove_reference_tokens(message, refs)
|
||||
final = stripped
|
||||
if warnings:
|
||||
final = f"{final}\n\n--- Context Warnings ---\n" + "\n".join(f"- {warning}" for warning in warnings)
|
||||
if blocks:
|
||||
final = f"{final}\n\n--- Attached Context ---\n\n" + "\n\n".join(blocks)
|
||||
|
||||
return ContextReferenceResult(
|
||||
message=final.strip(),
|
||||
original_message=message,
|
||||
references=refs,
|
||||
warnings=warnings,
|
||||
injected_tokens=injected_tokens,
|
||||
expanded=bool(blocks or warnings),
|
||||
blocked=False,
|
||||
)
|
||||
|
||||
|
||||
async def _expand_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
*,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
allowed_root: Path | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
try:
|
||||
if ref.kind == "file":
|
||||
return _expand_file_reference(ref, cwd, allowed_root=allowed_root)
|
||||
if ref.kind == "folder":
|
||||
return _expand_folder_reference(ref, cwd, allowed_root=allowed_root)
|
||||
if ref.kind == "diff":
|
||||
return _expand_git_reference(ref, cwd, ["diff"], "git diff")
|
||||
if ref.kind == "staged":
|
||||
return _expand_git_reference(ref, cwd, ["diff", "--staged"], "git diff --staged")
|
||||
if ref.kind == "git":
|
||||
count = max(1, min(int(ref.target or "1"), 10))
|
||||
return _expand_git_reference(ref, cwd, ["log", f"-{count}", "-p"], f"git log -{count} -p")
|
||||
if ref.kind == "url":
|
||||
content = await _fetch_url_content(ref.target, url_fetcher=url_fetcher)
|
||||
if not content:
|
||||
return f"{ref.raw}: no content extracted", None
|
||||
return None, f"🌐 {ref.raw} ({estimate_tokens_rough(content)} tokens)\n{content}"
|
||||
except Exception as exc:
|
||||
return f"{ref.raw}: {exc}", None
|
||||
|
||||
return f"{ref.raw}: unsupported reference type", None
|
||||
|
||||
|
||||
def _expand_file_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
*,
|
||||
allowed_root: Path | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
path = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
|
||||
_ensure_reference_path_allowed(path)
|
||||
if not path.exists():
|
||||
return f"{ref.raw}: file not found", None
|
||||
if not path.is_file():
|
||||
return f"{ref.raw}: path is not a file", None
|
||||
if _is_binary_file(path):
|
||||
return f"{ref.raw}: binary files are not supported", None
|
||||
|
||||
text = path.read_text(encoding="utf-8")
|
||||
if ref.line_start is not None:
|
||||
lines = text.splitlines()
|
||||
start_idx = max(ref.line_start - 1, 0)
|
||||
end_idx = min(ref.line_end or ref.line_start, len(lines))
|
||||
text = "\n".join(lines[start_idx:end_idx])
|
||||
|
||||
lang = _code_fence_language(path)
|
||||
label = ref.raw
|
||||
return None, f"📄 {label} ({estimate_tokens_rough(text)} tokens)\n```{lang}\n{text}\n```"
|
||||
|
||||
|
||||
def _expand_folder_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
*,
|
||||
allowed_root: Path | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
path = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
|
||||
_ensure_reference_path_allowed(path)
|
||||
if not path.exists():
|
||||
return f"{ref.raw}: folder not found", None
|
||||
if not path.is_dir():
|
||||
return f"{ref.raw}: path is not a folder", None
|
||||
|
||||
listing = _build_folder_listing(path, cwd)
|
||||
return None, f"📁 {ref.raw} ({estimate_tokens_rough(listing)} tokens)\n{listing}"
|
||||
|
||||
|
||||
def _expand_git_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
args: list[str],
|
||||
label: str,
|
||||
) -> tuple[str | None, str | None]:
|
||||
result = subprocess.run(
|
||||
["git", *args],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
stderr = (result.stderr or "").strip() or "git command failed"
|
||||
return f"{ref.raw}: {stderr}", None
|
||||
content = result.stdout.strip()
|
||||
if not content:
|
||||
content = "(no output)"
|
||||
return None, f"🧾 {label} ({estimate_tokens_rough(content)} tokens)\n```diff\n{content}\n```"
|
||||
|
||||
|
||||
async def _fetch_url_content(
|
||||
url: str,
|
||||
*,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
) -> str:
|
||||
fetcher = url_fetcher or _default_url_fetcher
|
||||
content = fetcher(url)
|
||||
if inspect.isawaitable(content):
|
||||
content = await content
|
||||
return str(content or "").strip()
|
||||
|
||||
|
||||
async def _default_url_fetcher(url: str) -> str:
|
||||
from tools.web_tools import web_extract_tool
|
||||
|
||||
raw = await web_extract_tool([url], format="markdown", use_llm_processing=True)
|
||||
payload = json.loads(raw)
|
||||
docs = payload.get("data", {}).get("documents", [])
|
||||
if not docs:
|
||||
return ""
|
||||
doc = docs[0]
|
||||
return str(doc.get("content") or doc.get("raw_content") or "").strip()
|
||||
|
||||
|
||||
def _resolve_path(cwd: Path, target: str, *, allowed_root: Path | None = None) -> Path:
|
||||
path = Path(os.path.expanduser(target))
|
||||
if not path.is_absolute():
|
||||
path = cwd / path
|
||||
resolved = path.resolve()
|
||||
if allowed_root is not None:
|
||||
try:
|
||||
resolved.relative_to(allowed_root)
|
||||
except ValueError as exc:
|
||||
raise ValueError("path is outside the allowed workspace") from exc
|
||||
return resolved
|
||||
|
||||
|
||||
def _ensure_reference_path_allowed(path: Path) -> None:
|
||||
home = Path(os.path.expanduser("~")).resolve()
|
||||
hermes_home = Path(
|
||||
os.getenv("HERMES_HOME", str(home / ".hermes"))
|
||||
).expanduser().resolve()
|
||||
|
||||
blocked_exact = {home / rel for rel in _SENSITIVE_HOME_FILES}
|
||||
blocked_exact.add(hermes_home / ".env")
|
||||
blocked_dirs = [home / rel for rel in _SENSITIVE_HOME_DIRS]
|
||||
blocked_dirs.extend(hermes_home / rel for rel in _SENSITIVE_HERMES_DIRS)
|
||||
|
||||
if path in blocked_exact:
|
||||
raise ValueError("path is a sensitive credential file and cannot be attached")
|
||||
|
||||
for blocked_dir in blocked_dirs:
|
||||
try:
|
||||
path.relative_to(blocked_dir)
|
||||
except ValueError:
|
||||
continue
|
||||
raise ValueError("path is a sensitive credential or internal Hermes path and cannot be attached")
|
||||
|
||||
|
||||
def _strip_trailing_punctuation(value: str) -> str:
|
||||
stripped = value.rstrip(TRAILING_PUNCTUATION)
|
||||
while stripped.endswith((")", "]", "}")):
|
||||
closer = stripped[-1]
|
||||
opener = {")": "(", "]": "[", "}": "{"}[closer]
|
||||
if stripped.count(closer) > stripped.count(opener):
|
||||
stripped = stripped[:-1]
|
||||
continue
|
||||
break
|
||||
return stripped
|
||||
|
||||
|
||||
def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
|
||||
pieces: list[str] = []
|
||||
cursor = 0
|
||||
for ref in refs:
|
||||
pieces.append(message[cursor:ref.start])
|
||||
cursor = ref.end
|
||||
pieces.append(message[cursor:])
|
||||
text = "".join(pieces)
|
||||
text = re.sub(r"\s{2,}", " ", text)
|
||||
text = re.sub(r"\s+([,.;:!?])", r"\1", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _is_binary_file(path: Path) -> bool:
|
||||
mime, _ = mimetypes.guess_type(path.name)
|
||||
if mime and not mime.startswith("text/") and not any(
|
||||
path.name.endswith(ext) for ext in (".py", ".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".js", ".ts")
|
||||
):
|
||||
return True
|
||||
chunk = path.read_bytes()[:4096]
|
||||
return b"\x00" in chunk
|
||||
|
||||
|
||||
def _build_folder_listing(path: Path, cwd: Path, limit: int = 200) -> str:
|
||||
lines = [f"{path.relative_to(cwd)}/"]
|
||||
entries = _iter_visible_entries(path, cwd, limit=limit)
|
||||
for entry in entries:
|
||||
rel = entry.relative_to(cwd)
|
||||
indent = " " * max(len(rel.parts) - len(path.relative_to(cwd).parts) - 1, 0)
|
||||
if entry.is_dir():
|
||||
lines.append(f"{indent}- {entry.name}/")
|
||||
else:
|
||||
meta = _file_metadata(entry)
|
||||
lines.append(f"{indent}- {entry.name} ({meta})")
|
||||
if len(entries) >= limit:
|
||||
lines.append("- ...")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _iter_visible_entries(path: Path, cwd: Path, limit: int) -> list[Path]:
|
||||
rg_entries = _rg_files(path, cwd, limit=limit)
|
||||
if rg_entries is not None:
|
||||
output: list[Path] = []
|
||||
seen_dirs: set[Path] = set()
|
||||
for rel in rg_entries:
|
||||
full = cwd / rel
|
||||
for parent in full.parents:
|
||||
if parent == cwd or parent in seen_dirs or path not in {parent, *parent.parents}:
|
||||
continue
|
||||
seen_dirs.add(parent)
|
||||
output.append(parent)
|
||||
output.append(full)
|
||||
return sorted({p for p in output if p.exists()}, key=lambda p: (not p.is_dir(), str(p)))
|
||||
|
||||
output = []
|
||||
for root, dirs, files in os.walk(path):
|
||||
dirs[:] = sorted(d for d in dirs if not d.startswith(".") and d != "__pycache__")
|
||||
files = sorted(f for f in files if not f.startswith("."))
|
||||
root_path = Path(root)
|
||||
for d in dirs:
|
||||
output.append(root_path / d)
|
||||
if len(output) >= limit:
|
||||
return output
|
||||
for f in files:
|
||||
output.append(root_path / f)
|
||||
if len(output) >= limit:
|
||||
return output
|
||||
return output
|
||||
|
||||
|
||||
def _rg_files(path: Path, cwd: Path, limit: int) -> list[Path] | None:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["rg", "--files", str(path.relative_to(cwd))],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
files = [Path(line.strip()) for line in result.stdout.splitlines() if line.strip()]
|
||||
return files[:limit]
|
||||
|
||||
|
||||
def _file_metadata(path: Path) -> str:
|
||||
if _is_binary_file(path):
|
||||
return f"{path.stat().st_size} bytes"
|
||||
try:
|
||||
line_count = path.read_text(encoding="utf-8").count("\n") + 1
|
||||
except Exception:
|
||||
return f"{path.stat().st_size} bytes"
|
||||
return f"{line_count} lines"
|
||||
|
||||
|
||||
def _code_fence_language(path: Path) -> str:
|
||||
mapping = {
|
||||
".py": "python",
|
||||
".js": "javascript",
|
||||
".ts": "typescript",
|
||||
".tsx": "tsx",
|
||||
".jsx": "jsx",
|
||||
".json": "json",
|
||||
".md": "markdown",
|
||||
".sh": "bash",
|
||||
".yml": "yaml",
|
||||
".yaml": "yaml",
|
||||
".toml": "toml",
|
||||
}
|
||||
return mapping.get(path.suffix.lower(), "")
|
||||
@@ -0,0 +1,447 @@
|
||||
"""OpenAI-compatible shim that forwards Hermes requests to `copilot --acp`.
|
||||
|
||||
This adapter lets Hermes treat the GitHub Copilot ACP server as a chat-style
|
||||
backend. Each request starts a short-lived ACP session, sends the formatted
|
||||
conversation as a single prompt, collects text chunks, and converts the result
|
||||
back into the minimal shape Hermes expects from an OpenAI client.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
ACP_MARKER_BASE_URL = "acp://copilot"
|
||||
_DEFAULT_TIMEOUT_SECONDS = 900.0
|
||||
|
||||
|
||||
def _resolve_command() -> str:
|
||||
return (
|
||||
os.getenv("HERMES_COPILOT_ACP_COMMAND", "").strip()
|
||||
or os.getenv("COPILOT_CLI_PATH", "").strip()
|
||||
or "copilot"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_args() -> list[str]:
|
||||
raw = os.getenv("HERMES_COPILOT_ACP_ARGS", "").strip()
|
||||
if not raw:
|
||||
return ["--acp", "--stdio"]
|
||||
return shlex.split(raw)
|
||||
|
||||
|
||||
def _jsonrpc_error(message_id: Any, code: int, message: str) -> dict[str, Any]:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _format_messages_as_prompt(messages: list[dict[str, Any]], model: str | None = None) -> str:
|
||||
sections: list[str] = [
|
||||
"You are being used as the active ACP agent backend for Hermes.",
|
||||
"Use your own ACP capabilities and respond directly in natural language.",
|
||||
"Do not emit OpenAI tool-call JSON.",
|
||||
]
|
||||
if model:
|
||||
sections.append(f"Hermes requested model hint: {model}")
|
||||
|
||||
transcript: list[str] = []
|
||||
for message in messages:
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
role = str(message.get("role") or "unknown").strip().lower()
|
||||
if role == "tool":
|
||||
role = "tool"
|
||||
elif role not in {"system", "user", "assistant"}:
|
||||
role = "context"
|
||||
|
||||
content = message.get("content")
|
||||
rendered = _render_message_content(content)
|
||||
if not rendered:
|
||||
continue
|
||||
|
||||
label = {
|
||||
"system": "System",
|
||||
"user": "User",
|
||||
"assistant": "Assistant",
|
||||
"tool": "Tool",
|
||||
"context": "Context",
|
||||
}.get(role, role.title())
|
||||
transcript.append(f"{label}:\n{rendered}")
|
||||
|
||||
if transcript:
|
||||
sections.append("Conversation transcript:\n\n" + "\n\n".join(transcript))
|
||||
|
||||
sections.append("Continue the conversation from the latest user request.")
|
||||
return "\n\n".join(section.strip() for section in sections if section and section.strip())
|
||||
|
||||
|
||||
def _render_message_content(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, dict):
|
||||
if "text" in content:
|
||||
return str(content.get("text") or "").strip()
|
||||
if "content" in content and isinstance(content.get("content"), str):
|
||||
return str(content.get("content") or "").strip()
|
||||
return json.dumps(content, ensure_ascii=True)
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str) and text.strip():
|
||||
parts.append(text.strip())
|
||||
return "\n".join(parts).strip()
|
||||
return str(content).strip()
|
||||
|
||||
|
||||
def _ensure_path_within_cwd(path_text: str, cwd: str) -> Path:
|
||||
candidate = Path(path_text)
|
||||
if not candidate.is_absolute():
|
||||
raise PermissionError("ACP file-system paths must be absolute.")
|
||||
resolved = candidate.resolve()
|
||||
root = Path(cwd).resolve()
|
||||
try:
|
||||
resolved.relative_to(root)
|
||||
except ValueError as exc:
|
||||
raise PermissionError(f"Path '{resolved}' is outside the session cwd '{root}'.") from exc
|
||||
return resolved
|
||||
|
||||
|
||||
class _ACPChatCompletions:
|
||||
def __init__(self, client: "CopilotACPClient"):
|
||||
self._client = client
|
||||
|
||||
def create(self, **kwargs: Any) -> Any:
|
||||
return self._client._create_chat_completion(**kwargs)
|
||||
|
||||
|
||||
class _ACPChatNamespace:
|
||||
def __init__(self, client: "CopilotACPClient"):
|
||||
self.completions = _ACPChatCompletions(client)
|
||||
|
||||
|
||||
class CopilotACPClient:
|
||||
"""Minimal OpenAI-client-compatible facade for Copilot ACP."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
default_headers: dict[str, str] | None = None,
|
||||
acp_command: str | None = None,
|
||||
acp_args: list[str] | None = None,
|
||||
acp_cwd: str | None = None,
|
||||
command: str | None = None,
|
||||
args: list[str] | None = None,
|
||||
**_: Any,
|
||||
):
|
||||
self.api_key = api_key or "copilot-acp"
|
||||
self.base_url = base_url or ACP_MARKER_BASE_URL
|
||||
self._default_headers = dict(default_headers or {})
|
||||
self._acp_command = acp_command or command or _resolve_command()
|
||||
self._acp_args = list(acp_args or args or _resolve_args())
|
||||
self._acp_cwd = str(Path(acp_cwd or os.getcwd()).resolve())
|
||||
self.chat = _ACPChatNamespace(self)
|
||||
self.is_closed = False
|
||||
self._active_process: subprocess.Popen[str] | None = None
|
||||
self._active_process_lock = threading.Lock()
|
||||
|
||||
def close(self) -> None:
|
||||
proc: subprocess.Popen[str] | None
|
||||
with self._active_process_lock:
|
||||
proc = self._active_process
|
||||
self._active_process = None
|
||||
self.is_closed = True
|
||||
if proc is None:
|
||||
return
|
||||
try:
|
||||
proc.terminate()
|
||||
proc.wait(timeout=2)
|
||||
except Exception:
|
||||
try:
|
||||
proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _create_chat_completion(
|
||||
self,
|
||||
*,
|
||||
model: str | None = None,
|
||||
messages: list[dict[str, Any]] | None = None,
|
||||
timeout: float | None = None,
|
||||
**_: Any,
|
||||
) -> Any:
|
||||
prompt_text = _format_messages_as_prompt(messages or [], model=model)
|
||||
response_text, reasoning_text = self._run_prompt(
|
||||
prompt_text,
|
||||
timeout_seconds=float(timeout or _DEFAULT_TIMEOUT_SECONDS),
|
||||
)
|
||||
|
||||
usage = SimpleNamespace(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
prompt_tokens_details=SimpleNamespace(cached_tokens=0),
|
||||
)
|
||||
assistant_message = SimpleNamespace(
|
||||
content=response_text,
|
||||
tool_calls=[],
|
||||
reasoning=reasoning_text or None,
|
||||
reasoning_content=reasoning_text or None,
|
||||
reasoning_details=None,
|
||||
)
|
||||
choice = SimpleNamespace(message=assistant_message, finish_reason="stop")
|
||||
return SimpleNamespace(
|
||||
choices=[choice],
|
||||
usage=usage,
|
||||
model=model or "copilot-acp",
|
||||
)
|
||||
|
||||
def _run_prompt(self, prompt_text: str, *, timeout_seconds: float) -> tuple[str, str]:
|
||||
try:
|
||||
proc = subprocess.Popen(
|
||||
[self._acp_command] + self._acp_args,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=self._acp_cwd,
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
raise RuntimeError(
|
||||
f"Could not start Copilot ACP command '{self._acp_command}'. "
|
||||
"Install GitHub Copilot CLI or set HERMES_COPILOT_ACP_COMMAND/COPILOT_CLI_PATH."
|
||||
) from exc
|
||||
|
||||
if proc.stdin is None or proc.stdout is None:
|
||||
proc.kill()
|
||||
raise RuntimeError("Copilot ACP process did not expose stdin/stdout pipes.")
|
||||
|
||||
self.is_closed = False
|
||||
with self._active_process_lock:
|
||||
self._active_process = proc
|
||||
|
||||
inbox: queue.Queue[dict[str, Any]] = queue.Queue()
|
||||
stderr_tail: deque[str] = deque(maxlen=40)
|
||||
|
||||
def _stdout_reader() -> None:
|
||||
for line in proc.stdout:
|
||||
try:
|
||||
inbox.put(json.loads(line))
|
||||
except Exception:
|
||||
inbox.put({"raw": line.rstrip("\n")})
|
||||
|
||||
def _stderr_reader() -> None:
|
||||
if proc.stderr is None:
|
||||
return
|
||||
for line in proc.stderr:
|
||||
stderr_tail.append(line.rstrip("\n"))
|
||||
|
||||
out_thread = threading.Thread(target=_stdout_reader, daemon=True)
|
||||
err_thread = threading.Thread(target=_stderr_reader, daemon=True)
|
||||
out_thread.start()
|
||||
err_thread.start()
|
||||
|
||||
next_id = 0
|
||||
|
||||
def _request(method: str, params: dict[str, Any], *, text_parts: list[str] | None = None, reasoning_parts: list[str] | None = None) -> Any:
|
||||
nonlocal next_id
|
||||
next_id += 1
|
||||
request_id = next_id
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
proc.stdin.write(json.dumps(payload) + "\n")
|
||||
proc.stdin.flush()
|
||||
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
if proc.poll() is not None:
|
||||
break
|
||||
try:
|
||||
msg = inbox.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
if self._handle_server_message(
|
||||
msg,
|
||||
process=proc,
|
||||
cwd=self._acp_cwd,
|
||||
text_parts=text_parts,
|
||||
reasoning_parts=reasoning_parts,
|
||||
):
|
||||
continue
|
||||
|
||||
if msg.get("id") != request_id:
|
||||
continue
|
||||
if "error" in msg:
|
||||
err = msg.get("error") or {}
|
||||
raise RuntimeError(
|
||||
f"Copilot ACP {method} failed: {err.get('message') or err}"
|
||||
)
|
||||
return msg.get("result")
|
||||
|
||||
stderr_text = "\n".join(stderr_tail).strip()
|
||||
if proc.poll() is not None and stderr_text:
|
||||
raise RuntimeError(f"Copilot ACP process exited early: {stderr_text}")
|
||||
raise TimeoutError(f"Timed out waiting for Copilot ACP response to {method}.")
|
||||
|
||||
try:
|
||||
_request(
|
||||
"initialize",
|
||||
{
|
||||
"protocolVersion": 1,
|
||||
"clientCapabilities": {
|
||||
"fs": {
|
||||
"readTextFile": True,
|
||||
"writeTextFile": True,
|
||||
}
|
||||
},
|
||||
"clientInfo": {
|
||||
"name": "hermes-agent",
|
||||
"title": "Hermes Agent",
|
||||
"version": "0.0.0",
|
||||
},
|
||||
},
|
||||
)
|
||||
session = _request(
|
||||
"session/new",
|
||||
{
|
||||
"cwd": self._acp_cwd,
|
||||
"mcpServers": [],
|
||||
},
|
||||
) or {}
|
||||
session_id = str(session.get("sessionId") or "").strip()
|
||||
if not session_id:
|
||||
raise RuntimeError("Copilot ACP did not return a sessionId.")
|
||||
|
||||
text_parts: list[str] = []
|
||||
reasoning_parts: list[str] = []
|
||||
_request(
|
||||
"session/prompt",
|
||||
{
|
||||
"sessionId": session_id,
|
||||
"prompt": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt_text,
|
||||
}
|
||||
],
|
||||
},
|
||||
text_parts=text_parts,
|
||||
reasoning_parts=reasoning_parts,
|
||||
)
|
||||
return "".join(text_parts), "".join(reasoning_parts)
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def _handle_server_message(
|
||||
self,
|
||||
msg: dict[str, Any],
|
||||
*,
|
||||
process: subprocess.Popen[str],
|
||||
cwd: str,
|
||||
text_parts: list[str] | None,
|
||||
reasoning_parts: list[str] | None,
|
||||
) -> bool:
|
||||
method = msg.get("method")
|
||||
if not isinstance(method, str):
|
||||
return False
|
||||
|
||||
if method == "session/update":
|
||||
params = msg.get("params") or {}
|
||||
update = params.get("update") or {}
|
||||
kind = str(update.get("sessionUpdate") or "").strip()
|
||||
content = update.get("content") or {}
|
||||
chunk_text = ""
|
||||
if isinstance(content, dict):
|
||||
chunk_text = str(content.get("text") or "")
|
||||
if kind == "agent_message_chunk" and chunk_text and text_parts is not None:
|
||||
text_parts.append(chunk_text)
|
||||
elif kind == "agent_thought_chunk" and chunk_text and reasoning_parts is not None:
|
||||
reasoning_parts.append(chunk_text)
|
||||
return True
|
||||
|
||||
if process.stdin is None:
|
||||
return True
|
||||
|
||||
message_id = msg.get("id")
|
||||
params = msg.get("params") or {}
|
||||
|
||||
if method == "session/request_permission":
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": {
|
||||
"outcome": {
|
||||
"outcome": "allow_once",
|
||||
}
|
||||
},
|
||||
}
|
||||
elif method == "fs/read_text_file":
|
||||
try:
|
||||
path = _ensure_path_within_cwd(str(params.get("path") or ""), cwd)
|
||||
content = path.read_text() if path.exists() else ""
|
||||
line = params.get("line")
|
||||
limit = params.get("limit")
|
||||
if isinstance(line, int) and line > 1:
|
||||
lines = content.splitlines(keepends=True)
|
||||
start = line - 1
|
||||
end = start + limit if isinstance(limit, int) and limit > 0 else None
|
||||
content = "".join(lines[start:end])
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": {
|
||||
"content": content,
|
||||
},
|
||||
}
|
||||
except Exception as exc:
|
||||
response = _jsonrpc_error(message_id, -32602, str(exc))
|
||||
elif method == "fs/write_text_file":
|
||||
try:
|
||||
path = _ensure_path_within_cwd(str(params.get("path") or ""), cwd)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(str(params.get("content") or ""))
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
response = _jsonrpc_error(message_id, -32602, str(exc))
|
||||
else:
|
||||
response = _jsonrpc_error(
|
||||
message_id,
|
||||
-32601,
|
||||
f"ACP client method '{method}' is not supported by Hermes yet.",
|
||||
)
|
||||
|
||||
process.stdin.write(json.dumps(response) + "\n")
|
||||
process.stdin.flush()
|
||||
return True
|
||||
+139
-5
@@ -59,6 +59,32 @@ def get_skin_tool_prefix() -> str:
|
||||
return "┊"
|
||||
|
||||
|
||||
def get_tool_emoji(tool_name: str, default: str = "⚡") -> str:
|
||||
"""Get the display emoji for a tool.
|
||||
|
||||
Resolution order:
|
||||
1. Active skin's ``tool_emojis`` overrides (if a skin is loaded)
|
||||
2. Tool registry's per-tool ``emoji`` field
|
||||
3. *default* fallback
|
||||
"""
|
||||
# 1. Skin override
|
||||
skin = _get_skin()
|
||||
if skin and skin.tool_emojis:
|
||||
override = skin.tool_emojis.get(tool_name)
|
||||
if override:
|
||||
return override
|
||||
# 2. Registry default
|
||||
try:
|
||||
from tools.registry import registry
|
||||
emoji = registry.get_emoji(tool_name, default="")
|
||||
if emoji:
|
||||
return emoji
|
||||
except Exception:
|
||||
pass
|
||||
# 3. Hardcoded fallback
|
||||
return default
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tool preview (one-line summary of a tool call's primary argument)
|
||||
# =========================================================================
|
||||
@@ -228,6 +254,15 @@ class KawaiiSpinner:
|
||||
pass
|
||||
|
||||
def _animate(self):
|
||||
# When stdout is not a real terminal (e.g. Docker, systemd, pipe),
|
||||
# skip the animation entirely — it creates massive log bloat.
|
||||
# Just log the start once and let stop() log the completion.
|
||||
if not hasattr(self._out, 'isatty') or not self._out.isatty():
|
||||
self._write(f" [tool] {self.message}", flush=True)
|
||||
while self.running:
|
||||
time.sleep(0.5)
|
||||
return
|
||||
|
||||
# Cache skin wings at start (avoid per-frame imports)
|
||||
skin = _get_skin()
|
||||
wings = skin.get_spinner_wings() if skin else []
|
||||
@@ -293,12 +328,19 @@ class KawaiiSpinner:
|
||||
self.running = False
|
||||
if self.thread:
|
||||
self.thread.join(timeout=0.5)
|
||||
# Clear the spinner line with spaces instead of \033[K to avoid
|
||||
# garbled escape codes when prompt_toolkit's patch_stdout is active.
|
||||
blanks = ' ' * max(self.last_line_len + 5, 40)
|
||||
self._write(f"\r{blanks}\r", end='', flush=True)
|
||||
|
||||
is_tty = hasattr(self._out, 'isatty') and self._out.isatty()
|
||||
if is_tty:
|
||||
# Clear the spinner line with spaces instead of \033[K to avoid
|
||||
# garbled escape codes when prompt_toolkit's patch_stdout is active.
|
||||
blanks = ' ' * max(self.last_line_len + 5, 40)
|
||||
self._write(f"\r{blanks}\r", end='', flush=True)
|
||||
if final_message:
|
||||
self._write(f" {final_message}", flush=True)
|
||||
elapsed = f" ({time.time() - self.start_time:.1f}s)" if self.start_time else ""
|
||||
if is_tty:
|
||||
self._write(f" {final_message}", flush=True)
|
||||
else:
|
||||
self._write(f" [done] {final_message}{elapsed}", flush=True)
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
@@ -586,3 +628,95 @@ def write_tty(text: str) -> None:
|
||||
except OSError:
|
||||
sys.stdout.write(text)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Context pressure display (CLI user-facing warnings)
|
||||
# =========================================================================
|
||||
|
||||
# ANSI color codes for context pressure tiers
|
||||
_CYAN = "\033[36m"
|
||||
_YELLOW = "\033[33m"
|
||||
_BOLD = "\033[1m"
|
||||
_DIM_ANSI = "\033[2m"
|
||||
|
||||
# Bar characters
|
||||
_BAR_FILLED = "▰"
|
||||
_BAR_EMPTY = "▱"
|
||||
_BAR_WIDTH = 20
|
||||
|
||||
|
||||
def format_context_pressure(
|
||||
compaction_progress: float,
|
||||
threshold_tokens: int,
|
||||
threshold_percent: float,
|
||||
compression_enabled: bool = True,
|
||||
) -> str:
|
||||
"""Build a formatted context pressure line for CLI display.
|
||||
|
||||
The bar and percentage show progress toward the compaction threshold,
|
||||
NOT the raw context window. 100% = compaction fires.
|
||||
|
||||
Uses ANSI colors:
|
||||
- cyan at ~60% to compaction = informational
|
||||
- bold yellow at ~85% to compaction = warning
|
||||
|
||||
Args:
|
||||
compaction_progress: How close to compaction (0.0–1.0, 1.0 = fires).
|
||||
threshold_tokens: Compaction threshold in tokens.
|
||||
threshold_percent: Compaction threshold as a fraction of context window.
|
||||
compression_enabled: Whether auto-compression is active.
|
||||
"""
|
||||
pct_int = int(compaction_progress * 100)
|
||||
filled = min(int(compaction_progress * _BAR_WIDTH), _BAR_WIDTH)
|
||||
bar = _BAR_FILLED * filled + _BAR_EMPTY * (_BAR_WIDTH - filled)
|
||||
|
||||
threshold_k = f"{threshold_tokens // 1000}k" if threshold_tokens >= 1000 else str(threshold_tokens)
|
||||
threshold_pct_int = int(threshold_percent * 100)
|
||||
|
||||
# Tier styling
|
||||
if compaction_progress >= 0.85:
|
||||
color = f"{_BOLD}{_YELLOW}"
|
||||
icon = "⚠"
|
||||
if compression_enabled:
|
||||
hint = "compaction imminent"
|
||||
else:
|
||||
hint = "no auto-compaction"
|
||||
else:
|
||||
color = _CYAN
|
||||
icon = "◐"
|
||||
hint = "approaching compaction"
|
||||
|
||||
return (
|
||||
f" {color}{icon} context {bar} {pct_int}% to compaction{_ANSI_RESET}"
|
||||
f" {_DIM_ANSI}{threshold_k} threshold ({threshold_pct_int}%) · {hint}{_ANSI_RESET}"
|
||||
)
|
||||
|
||||
|
||||
def format_context_pressure_gateway(
|
||||
compaction_progress: float,
|
||||
threshold_percent: float,
|
||||
compression_enabled: bool = True,
|
||||
) -> str:
|
||||
"""Build a plain-text context pressure notification for messaging platforms.
|
||||
|
||||
No ANSI — just Unicode and plain text suitable for Telegram/Discord/etc.
|
||||
The percentage shows progress toward the compaction threshold.
|
||||
"""
|
||||
pct_int = int(compaction_progress * 100)
|
||||
filled = min(int(compaction_progress * _BAR_WIDTH), _BAR_WIDTH)
|
||||
bar = _BAR_FILLED * filled + _BAR_EMPTY * (_BAR_WIDTH - filled)
|
||||
|
||||
threshold_pct_int = int(threshold_percent * 100)
|
||||
|
||||
if compaction_progress >= 0.85:
|
||||
icon = "⚠️"
|
||||
if compression_enabled:
|
||||
hint = f"Context compaction is imminent (threshold: {threshold_pct_int}% of window)."
|
||||
else:
|
||||
hint = "Auto-compaction is disabled — context may be truncated."
|
||||
else:
|
||||
icon = "ℹ️"
|
||||
hint = f"Compaction threshold is at {threshold_pct_int}% of context window."
|
||||
|
||||
return f"{icon} Context: {bar} {pct_int}% to compaction\n{hint}"
|
||||
|
||||
+106
-132
@@ -20,65 +20,23 @@ import json
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List
|
||||
|
||||
# =========================================================================
|
||||
# Model pricing (USD per million tokens) — approximate as of early 2026
|
||||
# =========================================================================
|
||||
MODEL_PRICING = {
|
||||
# OpenAI
|
||||
"gpt-4o": {"input": 2.50, "output": 10.00},
|
||||
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
|
||||
"gpt-4.1": {"input": 2.00, "output": 8.00},
|
||||
"gpt-4.1-mini": {"input": 0.40, "output": 1.60},
|
||||
"gpt-4.1-nano": {"input": 0.10, "output": 0.40},
|
||||
"gpt-4.5-preview": {"input": 75.00, "output": 150.00},
|
||||
"gpt-5": {"input": 10.00, "output": 30.00},
|
||||
"gpt-5.4": {"input": 10.00, "output": 30.00},
|
||||
"o3": {"input": 10.00, "output": 40.00},
|
||||
"o3-mini": {"input": 1.10, "output": 4.40},
|
||||
"o4-mini": {"input": 1.10, "output": 4.40},
|
||||
# Anthropic
|
||||
"claude-opus-4-20250514": {"input": 15.00, "output": 75.00},
|
||||
"claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00},
|
||||
"claude-3-opus-20240229": {"input": 15.00, "output": 75.00},
|
||||
"claude-3-haiku-20240307": {"input": 0.25, "output": 1.25},
|
||||
# DeepSeek
|
||||
"deepseek-chat": {"input": 0.14, "output": 0.28},
|
||||
"deepseek-reasoner": {"input": 0.55, "output": 2.19},
|
||||
# Google
|
||||
"gemini-2.5-pro": {"input": 1.25, "output": 10.00},
|
||||
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
|
||||
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
||||
# Meta (via providers)
|
||||
"llama-4-maverick": {"input": 0.50, "output": 0.70},
|
||||
"llama-4-scout": {"input": 0.20, "output": 0.30},
|
||||
# Z.AI / GLM (direct provider — pricing not published externally, treat as local)
|
||||
"glm-5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.7": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5-flash": {"input": 0.0, "output": 0.0},
|
||||
# Kimi / Moonshot (direct provider — pricing not published externally, treat as local)
|
||||
"kimi-k2.5": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-thinking": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-0905-preview": {"input": 0.0, "output": 0.0},
|
||||
# MiniMax (direct provider — pricing not published externally, treat as local)
|
||||
"MiniMax-M2.5": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.1": {"input": 0.0, "output": 0.0},
|
||||
}
|
||||
from agent.usage_pricing import (
|
||||
CanonicalUsage,
|
||||
DEFAULT_PRICING,
|
||||
estimate_usage_cost,
|
||||
format_duration_compact,
|
||||
get_pricing,
|
||||
has_known_pricing,
|
||||
)
|
||||
|
||||
# Fallback: unknown/custom models get zero cost (we can't assume pricing
|
||||
# for self-hosted models, custom OAI endpoints, local inference, etc.)
|
||||
_DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
|
||||
_DEFAULT_PRICING = DEFAULT_PRICING
|
||||
|
||||
|
||||
def _has_known_pricing(model_name: str) -> bool:
|
||||
def _has_known_pricing(model_name: str, provider: str = None, base_url: str = None) -> bool:
|
||||
"""Check if a model has known pricing (vs unknown/custom endpoint)."""
|
||||
return _get_pricing(model_name) is not _DEFAULT_PRICING
|
||||
return has_known_pricing(model_name, provider=provider, base_url=base_url)
|
||||
|
||||
|
||||
def _get_pricing(model_name: str) -> Dict[str, float]:
|
||||
@@ -87,67 +45,51 @@ def _get_pricing(model_name: str) -> Dict[str, float]:
|
||||
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
|
||||
we can't assume costs for self-hosted endpoints, local inference, etc.
|
||||
"""
|
||||
if not model_name:
|
||||
return _DEFAULT_PRICING
|
||||
|
||||
# Strip provider prefix (e.g., "anthropic/claude-..." -> "claude-...")
|
||||
bare = model_name.split("/")[-1].lower()
|
||||
|
||||
# Exact match first
|
||||
if bare in MODEL_PRICING:
|
||||
return MODEL_PRICING[bare]
|
||||
|
||||
# Fuzzy prefix match — prefer the LONGEST matching key to avoid
|
||||
# e.g. "gpt-4o" matching before "gpt-4o-mini" for "gpt-4o-mini-2024-07-18"
|
||||
best_match = None
|
||||
best_len = 0
|
||||
for key, price in MODEL_PRICING.items():
|
||||
if bare.startswith(key) and len(key) > best_len:
|
||||
best_match = price
|
||||
best_len = len(key)
|
||||
if best_match:
|
||||
return best_match
|
||||
|
||||
# Keyword heuristics (checked in most-specific-first order)
|
||||
if "opus" in bare:
|
||||
return {"input": 15.00, "output": 75.00}
|
||||
if "sonnet" in bare:
|
||||
return {"input": 3.00, "output": 15.00}
|
||||
if "haiku" in bare:
|
||||
return {"input": 0.80, "output": 4.00}
|
||||
if "gpt-4o-mini" in bare:
|
||||
return {"input": 0.15, "output": 0.60}
|
||||
if "gpt-4o" in bare:
|
||||
return {"input": 2.50, "output": 10.00}
|
||||
if "gpt-5" in bare:
|
||||
return {"input": 10.00, "output": 30.00}
|
||||
if "deepseek" in bare:
|
||||
return {"input": 0.14, "output": 0.28}
|
||||
if "gemini" in bare:
|
||||
return {"input": 0.15, "output": 0.60}
|
||||
|
||||
return _DEFAULT_PRICING
|
||||
return get_pricing(model_name)
|
||||
|
||||
|
||||
def _estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
|
||||
"""Estimate the USD cost for a given model and token counts."""
|
||||
pricing = _get_pricing(model)
|
||||
return (input_tokens * pricing["input"] + output_tokens * pricing["output"]) / 1_000_000
|
||||
def _estimate_cost(
|
||||
session_or_model: Dict[str, Any] | str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
provider: str = None,
|
||||
base_url: str = None,
|
||||
) -> tuple[float, str]:
|
||||
"""Estimate the USD cost for a session row or a model/token tuple."""
|
||||
if isinstance(session_or_model, dict):
|
||||
session = session_or_model
|
||||
model = session.get("model") or ""
|
||||
usage = CanonicalUsage(
|
||||
input_tokens=session.get("input_tokens") or 0,
|
||||
output_tokens=session.get("output_tokens") or 0,
|
||||
cache_read_tokens=session.get("cache_read_tokens") or 0,
|
||||
cache_write_tokens=session.get("cache_write_tokens") or 0,
|
||||
)
|
||||
provider = session.get("billing_provider")
|
||||
base_url = session.get("billing_base_url")
|
||||
else:
|
||||
model = session_or_model or ""
|
||||
usage = CanonicalUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
)
|
||||
result = estimate_usage_cost(
|
||||
model,
|
||||
usage,
|
||||
provider=provider,
|
||||
base_url=base_url,
|
||||
)
|
||||
return float(result.amount_usd or 0.0), result.status
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
"""Format seconds into a human-readable duration string."""
|
||||
if seconds < 60:
|
||||
return f"{seconds:.0f}s"
|
||||
minutes = seconds / 60
|
||||
if minutes < 60:
|
||||
return f"{minutes:.0f}m"
|
||||
hours = minutes / 60
|
||||
if hours < 24:
|
||||
remaining_min = int(minutes % 60)
|
||||
return f"{int(hours)}h {remaining_min}m" if remaining_min else f"{int(hours)}h"
|
||||
days = hours / 24
|
||||
return f"{days:.1f}d"
|
||||
return format_duration_compact(seconds)
|
||||
|
||||
|
||||
def _bar_chart(values: List[int], max_width: int = 20) -> List[str]:
|
||||
@@ -234,24 +176,30 @@ class InsightsEngine:
|
||||
|
||||
# Columns we actually need (skip system_prompt, model_config blobs)
|
||||
_SESSION_COLS = ("id, source, model, started_at, ended_at, "
|
||||
"message_count, tool_call_count, input_tokens, output_tokens")
|
||||
"message_count, tool_call_count, input_tokens, output_tokens, "
|
||||
"cache_read_tokens, cache_write_tokens, billing_provider, "
|
||||
"billing_base_url, billing_mode, estimated_cost_usd, "
|
||||
"actual_cost_usd, cost_status, cost_source")
|
||||
|
||||
# Pre-computed query strings — f-string evaluated once at class definition,
|
||||
# not at runtime, so no user-controlled value can alter the query structure.
|
||||
_GET_SESSIONS_WITH_SOURCE = (
|
||||
f"SELECT {_SESSION_COLS} FROM sessions"
|
||||
" WHERE started_at >= ? AND source = ?"
|
||||
" ORDER BY started_at DESC"
|
||||
)
|
||||
_GET_SESSIONS_ALL = (
|
||||
f"SELECT {_SESSION_COLS} FROM sessions"
|
||||
" WHERE started_at >= ?"
|
||||
" ORDER BY started_at DESC"
|
||||
)
|
||||
|
||||
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
|
||||
"""Fetch sessions within the time window."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
f"""SELECT {self._SESSION_COLS} FROM sessions
|
||||
WHERE started_at >= ? AND source = ?
|
||||
ORDER BY started_at DESC""",
|
||||
(cutoff, source),
|
||||
)
|
||||
cursor = self._conn.execute(self._GET_SESSIONS_WITH_SOURCE, (cutoff, source))
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
f"""SELECT {self._SESSION_COLS} FROM sessions
|
||||
WHERE started_at >= ?
|
||||
ORDER BY started_at DESC""",
|
||||
(cutoff,),
|
||||
)
|
||||
cursor = self._conn.execute(self._GET_SESSIONS_ALL, (cutoff,))
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]:
|
||||
@@ -386,21 +334,30 @@ class InsightsEngine:
|
||||
"""Compute high-level overview statistics."""
|
||||
total_input = sum(s.get("input_tokens") or 0 for s in sessions)
|
||||
total_output = sum(s.get("output_tokens") or 0 for s in sessions)
|
||||
total_tokens = total_input + total_output
|
||||
total_cache_read = sum(s.get("cache_read_tokens") or 0 for s in sessions)
|
||||
total_cache_write = sum(s.get("cache_write_tokens") or 0 for s in sessions)
|
||||
total_tokens = total_input + total_output + total_cache_read + total_cache_write
|
||||
total_tool_calls = sum(s.get("tool_call_count") or 0 for s in sessions)
|
||||
total_messages = sum(s.get("message_count") or 0 for s in sessions)
|
||||
|
||||
# Cost estimation (weighted by model)
|
||||
total_cost = 0.0
|
||||
actual_cost = 0.0
|
||||
models_with_pricing = set()
|
||||
models_without_pricing = set()
|
||||
unknown_cost_sessions = 0
|
||||
included_cost_sessions = 0
|
||||
for s in sessions:
|
||||
model = s.get("model") or ""
|
||||
inp = s.get("input_tokens") or 0
|
||||
out = s.get("output_tokens") or 0
|
||||
total_cost += _estimate_cost(model, inp, out)
|
||||
estimated, status = _estimate_cost(s)
|
||||
total_cost += estimated
|
||||
actual_cost += s.get("actual_cost_usd") or 0.0
|
||||
display = model.split("/")[-1] if "/" in model else (model or "unknown")
|
||||
if _has_known_pricing(model):
|
||||
if status == "included":
|
||||
included_cost_sessions += 1
|
||||
elif status == "unknown":
|
||||
unknown_cost_sessions += 1
|
||||
if _has_known_pricing(model, s.get("billing_provider"), s.get("billing_base_url")):
|
||||
models_with_pricing.add(display)
|
||||
else:
|
||||
models_without_pricing.add(display)
|
||||
@@ -427,8 +384,11 @@ class InsightsEngine:
|
||||
"total_tool_calls": total_tool_calls,
|
||||
"total_input_tokens": total_input,
|
||||
"total_output_tokens": total_output,
|
||||
"total_cache_read_tokens": total_cache_read,
|
||||
"total_cache_write_tokens": total_cache_write,
|
||||
"total_tokens": total_tokens,
|
||||
"estimated_cost": total_cost,
|
||||
"actual_cost": actual_cost,
|
||||
"total_hours": total_hours,
|
||||
"avg_session_duration": avg_duration,
|
||||
"avg_messages_per_session": total_messages / len(sessions) if sessions else 0,
|
||||
@@ -440,12 +400,15 @@ class InsightsEngine:
|
||||
"date_range_end": date_range_end,
|
||||
"models_with_pricing": sorted(models_with_pricing),
|
||||
"models_without_pricing": sorted(models_without_pricing),
|
||||
"unknown_cost_sessions": unknown_cost_sessions,
|
||||
"included_cost_sessions": included_cost_sessions,
|
||||
}
|
||||
|
||||
def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
|
||||
"""Break down usage by model."""
|
||||
model_data = defaultdict(lambda: {
|
||||
"sessions": 0, "input_tokens": 0, "output_tokens": 0,
|
||||
"cache_read_tokens": 0, "cache_write_tokens": 0,
|
||||
"total_tokens": 0, "tool_calls": 0, "cost": 0.0,
|
||||
})
|
||||
|
||||
@@ -457,12 +420,18 @@ class InsightsEngine:
|
||||
d["sessions"] += 1
|
||||
inp = s.get("input_tokens") or 0
|
||||
out = s.get("output_tokens") or 0
|
||||
cache_read = s.get("cache_read_tokens") or 0
|
||||
cache_write = s.get("cache_write_tokens") or 0
|
||||
d["input_tokens"] += inp
|
||||
d["output_tokens"] += out
|
||||
d["total_tokens"] += inp + out
|
||||
d["cache_read_tokens"] += cache_read
|
||||
d["cache_write_tokens"] += cache_write
|
||||
d["total_tokens"] += inp + out + cache_read + cache_write
|
||||
d["tool_calls"] += s.get("tool_call_count") or 0
|
||||
d["cost"] += _estimate_cost(model, inp, out)
|
||||
d["has_pricing"] = _has_known_pricing(model)
|
||||
estimate, status = _estimate_cost(s)
|
||||
d["cost"] += estimate
|
||||
d["has_pricing"] = _has_known_pricing(model, s.get("billing_provider"), s.get("billing_base_url"))
|
||||
d["cost_status"] = status
|
||||
|
||||
result = [
|
||||
{"model": model, **data}
|
||||
@@ -476,7 +445,8 @@ class InsightsEngine:
|
||||
"""Break down usage by platform/source."""
|
||||
platform_data = defaultdict(lambda: {
|
||||
"sessions": 0, "messages": 0, "input_tokens": 0,
|
||||
"output_tokens": 0, "total_tokens": 0, "tool_calls": 0,
|
||||
"output_tokens": 0, "cache_read_tokens": 0,
|
||||
"cache_write_tokens": 0, "total_tokens": 0, "tool_calls": 0,
|
||||
})
|
||||
|
||||
for s in sessions:
|
||||
@@ -486,9 +456,13 @@ class InsightsEngine:
|
||||
d["messages"] += s.get("message_count") or 0
|
||||
inp = s.get("input_tokens") or 0
|
||||
out = s.get("output_tokens") or 0
|
||||
cache_read = s.get("cache_read_tokens") or 0
|
||||
cache_write = s.get("cache_write_tokens") or 0
|
||||
d["input_tokens"] += inp
|
||||
d["output_tokens"] += out
|
||||
d["total_tokens"] += inp + out
|
||||
d["cache_read_tokens"] += cache_read
|
||||
d["cache_write_tokens"] += cache_write
|
||||
d["total_tokens"] += inp + out + cache_read + cache_write
|
||||
d["tool_calls"] += s.get("tool_call_count") or 0
|
||||
|
||||
result = [
|
||||
|
||||
+717
-55
@@ -10,6 +10,7 @@ import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
@@ -18,61 +19,346 @@ from hermes_constants import OPENROUTER_MODELS_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider names that can appear as a "provider:" prefix before a model ID.
|
||||
# Only these are stripped — Ollama-style "model:tag" colons (e.g. "qwen3.5:27b")
|
||||
# are preserved so the full model name reaches cache lookups and server queries.
|
||||
_PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
||||
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn", "anthropic", "deepseek",
|
||||
"opencode-zen", "opencode-go", "ai-gateway", "kilocode", "alibaba",
|
||||
"custom", "local",
|
||||
# Common aliases
|
||||
"glm", "z-ai", "z.ai", "zhipu", "github", "github-copilot",
|
||||
"github-models", "kimi", "moonshot", "claude", "deep-seek",
|
||||
"opencode", "zen", "go", "vercel", "kilo", "dashscope", "aliyun", "qwen",
|
||||
})
|
||||
|
||||
|
||||
_OLLAMA_TAG_PATTERN = re.compile(
|
||||
r"^(\d+\.?\d*b|latest|stable|q\d|fp?\d|instruct|chat|coder|vision|text)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _strip_provider_prefix(model: str) -> str:
|
||||
"""Strip a recognised provider prefix from a model string.
|
||||
|
||||
``"local:my-model"`` → ``"my-model"``
|
||||
``"qwen3.5:27b"`` → ``"qwen3.5:27b"`` (unchanged — not a provider prefix)
|
||||
``"qwen:0.5b"`` → ``"qwen:0.5b"`` (unchanged — Ollama model:tag)
|
||||
``"deepseek:latest"``→ ``"deepseek:latest"``(unchanged — Ollama model:tag)
|
||||
"""
|
||||
if ":" not in model or model.startswith("http"):
|
||||
return model
|
||||
prefix, suffix = model.split(":", 1)
|
||||
prefix_lower = prefix.strip().lower()
|
||||
if prefix_lower in _PROVIDER_PREFIXES:
|
||||
# Don't strip if suffix looks like an Ollama tag (e.g. "7b", "latest", "q4_0")
|
||||
if _OLLAMA_TAG_PATTERN.match(suffix.strip()):
|
||||
return model
|
||||
return suffix
|
||||
return model
|
||||
|
||||
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_model_metadata_cache_time: float = 0
|
||||
_MODEL_CACHE_TTL = 3600
|
||||
_endpoint_model_metadata_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
|
||||
_endpoint_model_metadata_cache_time: Dict[str, float] = {}
|
||||
_ENDPOINT_MODEL_CACHE_TTL = 300
|
||||
|
||||
# Descending tiers for context length probing when the model is unknown.
|
||||
# We start high and step down on context-length errors until one works.
|
||||
# We start at 128K (a safe default for most modern models) and step down
|
||||
# on context-length errors until one works.
|
||||
CONTEXT_PROBE_TIERS = [
|
||||
2_000_000,
|
||||
1_000_000,
|
||||
512_000,
|
||||
200_000,
|
||||
128_000,
|
||||
64_000,
|
||||
32_000,
|
||||
16_000,
|
||||
8_000,
|
||||
]
|
||||
|
||||
# Default context length when no detection method succeeds.
|
||||
DEFAULT_FALLBACK_CONTEXT = CONTEXT_PROBE_TIERS[0]
|
||||
|
||||
# Thin fallback defaults — only broad model family patterns.
|
||||
# These fire only when provider is unknown AND models.dev/OpenRouter/Anthropic
|
||||
# all miss. Replaced the previous 80+ entry dict.
|
||||
# For provider-specific context lengths, models.dev is the primary source.
|
||||
DEFAULT_CONTEXT_LENGTHS = {
|
||||
"anthropic/claude-opus-4": 200000,
|
||||
"anthropic/claude-opus-4.5": 200000,
|
||||
"anthropic/claude-opus-4.6": 200000,
|
||||
"anthropic/claude-sonnet-4": 200000,
|
||||
"anthropic/claude-sonnet-4-20250514": 200000,
|
||||
"anthropic/claude-haiku-4.5": 200000,
|
||||
# Bare Anthropic model IDs (for native API provider)
|
||||
"claude-opus-4-6": 200000,
|
||||
"claude-sonnet-4-6": 200000,
|
||||
"claude-opus-4-5-20251101": 200000,
|
||||
"claude-sonnet-4-5-20250929": 200000,
|
||||
"claude-opus-4-1-20250805": 200000,
|
||||
"claude-opus-4-20250514": 200000,
|
||||
"claude-sonnet-4-20250514": 200000,
|
||||
"claude-haiku-4-5-20251001": 200000,
|
||||
"openai/gpt-4o": 128000,
|
||||
"openai/gpt-4-turbo": 128000,
|
||||
"openai/gpt-4o-mini": 128000,
|
||||
"google/gemini-2.0-flash": 1048576,
|
||||
"google/gemini-2.5-pro": 1048576,
|
||||
"meta-llama/llama-3.3-70b-instruct": 131072,
|
||||
"deepseek/deepseek-chat-v3": 65536,
|
||||
"qwen/qwen-2.5-72b-instruct": 32768,
|
||||
"glm-4.7": 202752,
|
||||
"glm-5": 202752,
|
||||
"glm-4.5": 131072,
|
||||
"glm-4.5-flash": 131072,
|
||||
"kimi-for-coding": 262144,
|
||||
"kimi-k2.5": 262144,
|
||||
"kimi-k2-thinking": 262144,
|
||||
"kimi-k2-thinking-turbo": 262144,
|
||||
"kimi-k2-turbo-preview": 262144,
|
||||
"kimi-k2-0905-preview": 131072,
|
||||
"MiniMax-M2.5": 204800,
|
||||
"MiniMax-M2.5-highspeed": 204800,
|
||||
"MiniMax-M2.1": 204800,
|
||||
# Anthropic Claude 4.6 (1M context) — bare IDs only to avoid
|
||||
# fuzzy-match collisions (e.g. "anthropic/claude-sonnet-4" is a
|
||||
# substring of "anthropic/claude-sonnet-4.6").
|
||||
# OpenRouter-prefixed models resolve via OpenRouter live API or models.dev.
|
||||
"claude-opus-4-6": 1000000,
|
||||
"claude-sonnet-4-6": 1000000,
|
||||
"claude-opus-4.6": 1000000,
|
||||
"claude-sonnet-4.6": 1000000,
|
||||
# Catch-all for older Claude models (must sort after specific entries)
|
||||
"claude": 200000,
|
||||
# OpenAI
|
||||
"gpt-4.1": 1047576,
|
||||
"gpt-5": 128000,
|
||||
"gpt-4": 128000,
|
||||
# Google
|
||||
"gemini": 1048576,
|
||||
# DeepSeek
|
||||
"deepseek": 128000,
|
||||
# Meta
|
||||
"llama": 131072,
|
||||
# Qwen
|
||||
"qwen": 131072,
|
||||
# MiniMax
|
||||
"minimax": 204800,
|
||||
# GLM
|
||||
"glm": 202752,
|
||||
# Kimi
|
||||
"kimi": 262144,
|
||||
}
|
||||
|
||||
_CONTEXT_LENGTH_KEYS = (
|
||||
"context_length",
|
||||
"context_window",
|
||||
"max_context_length",
|
||||
"max_position_embeddings",
|
||||
"max_model_len",
|
||||
"max_input_tokens",
|
||||
"max_sequence_length",
|
||||
"max_seq_len",
|
||||
"n_ctx_train",
|
||||
"n_ctx",
|
||||
)
|
||||
|
||||
_MAX_COMPLETION_KEYS = (
|
||||
"max_completion_tokens",
|
||||
"max_output_tokens",
|
||||
"max_tokens",
|
||||
)
|
||||
|
||||
# Local server hostnames / address patterns
|
||||
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
|
||||
|
||||
|
||||
def _normalize_base_url(base_url: str) -> str:
|
||||
return (base_url or "").strip().rstrip("/")
|
||||
|
||||
|
||||
def _is_openrouter_base_url(base_url: str) -> bool:
|
||||
return "openrouter.ai" in _normalize_base_url(base_url).lower()
|
||||
|
||||
|
||||
def _is_custom_endpoint(base_url: str) -> bool:
|
||||
normalized = _normalize_base_url(base_url)
|
||||
return bool(normalized) and not _is_openrouter_base_url(normalized)
|
||||
|
||||
|
||||
_URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"api.openai.com": "openai",
|
||||
"chatgpt.com": "openai",
|
||||
"api.anthropic.com": "anthropic",
|
||||
"api.z.ai": "zai",
|
||||
"api.moonshot.ai": "kimi-coding",
|
||||
"api.kimi.com": "kimi-coding",
|
||||
"api.minimax": "minimax",
|
||||
"dashscope.aliyuncs.com": "alibaba",
|
||||
"dashscope-intl.aliyuncs.com": "alibaba",
|
||||
"openrouter.ai": "openrouter",
|
||||
"inference-api.nousresearch.com": "nous",
|
||||
"api.deepseek.com": "deepseek",
|
||||
"api.githubcopilot.com": "copilot",
|
||||
"models.github.ai": "copilot",
|
||||
}
|
||||
|
||||
|
||||
def _infer_provider_from_url(base_url: str) -> Optional[str]:
|
||||
"""Infer the models.dev provider name from a base URL.
|
||||
|
||||
This allows context length resolution via models.dev for custom endpoints
|
||||
like DashScope (Alibaba), Z.AI, Kimi, etc. without requiring the user to
|
||||
explicitly set the provider name in config.
|
||||
"""
|
||||
normalized = _normalize_base_url(base_url)
|
||||
if not normalized:
|
||||
return None
|
||||
parsed = urlparse(normalized if "://" in normalized else f"https://{normalized}")
|
||||
host = parsed.netloc.lower() or parsed.path.lower()
|
||||
for url_part, provider in _URL_TO_PROVIDER.items():
|
||||
if url_part in host:
|
||||
return provider
|
||||
return None
|
||||
|
||||
|
||||
def _is_known_provider_base_url(base_url: str) -> bool:
|
||||
return _infer_provider_from_url(base_url) is not None
|
||||
|
||||
|
||||
def is_local_endpoint(base_url: str) -> bool:
|
||||
"""Return True if base_url points to a local machine (localhost / RFC-1918 / WSL)."""
|
||||
normalized = _normalize_base_url(base_url)
|
||||
if not normalized:
|
||||
return False
|
||||
url = normalized if "://" in normalized else f"http://{normalized}"
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
host = parsed.hostname or ""
|
||||
except Exception:
|
||||
return False
|
||||
if host in _LOCAL_HOSTS:
|
||||
return True
|
||||
# RFC-1918 private ranges and link-local
|
||||
import ipaddress
|
||||
try:
|
||||
addr = ipaddress.ip_address(host)
|
||||
return addr.is_private or addr.is_loopback or addr.is_link_local
|
||||
except ValueError:
|
||||
pass
|
||||
# Bare IP that looks like a private range (e.g. 172.26.x.x for WSL)
|
||||
parts = host.split(".")
|
||||
if len(parts) == 4:
|
||||
try:
|
||||
first, second = int(parts[0]), int(parts[1])
|
||||
if first == 10:
|
||||
return True
|
||||
if first == 172 and 16 <= second <= 31:
|
||||
return True
|
||||
if first == 192 and second == 168:
|
||||
return True
|
||||
except ValueError:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def detect_local_server_type(base_url: str) -> Optional[str]:
|
||||
"""Detect which local server is running at base_url by probing known endpoints.
|
||||
|
||||
Returns one of: "ollama", "lm-studio", "vllm", "llamacpp", or None.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
normalized = _normalize_base_url(base_url)
|
||||
server_url = normalized
|
||||
if server_url.endswith("/v1"):
|
||||
server_url = server_url[:-3]
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=2.0) as client:
|
||||
# LM Studio exposes /api/v1/models — check first (most specific)
|
||||
try:
|
||||
r = client.get(f"{server_url}/api/v1/models")
|
||||
if r.status_code == 200:
|
||||
return "lm-studio"
|
||||
except Exception:
|
||||
pass
|
||||
# Ollama exposes /api/tags and responds with {"models": [...]}
|
||||
# LM Studio returns {"error": "Unexpected endpoint"} with status 200
|
||||
# on this path, so we must verify the response contains "models".
|
||||
try:
|
||||
r = client.get(f"{server_url}/api/tags")
|
||||
if r.status_code == 200:
|
||||
try:
|
||||
data = r.json()
|
||||
if "models" in data:
|
||||
return "ollama"
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
# llama.cpp exposes /v1/props (older builds used /props without the /v1 prefix)
|
||||
try:
|
||||
r = client.get(f"{server_url}/v1/props")
|
||||
if r.status_code != 200:
|
||||
r = client.get(f"{server_url}/props") # fallback for older builds
|
||||
if r.status_code == 200 and "default_generation_settings" in r.text:
|
||||
return "llamacpp"
|
||||
except Exception:
|
||||
pass
|
||||
# vLLM: /version
|
||||
try:
|
||||
r = client.get(f"{server_url}/version")
|
||||
if r.status_code == 200:
|
||||
data = r.json()
|
||||
if "version" in data:
|
||||
return "vllm"
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _iter_nested_dicts(value: Any):
|
||||
if isinstance(value, dict):
|
||||
yield value
|
||||
for nested in value.values():
|
||||
yield from _iter_nested_dicts(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
yield from _iter_nested_dicts(item)
|
||||
|
||||
|
||||
def _coerce_reasonable_int(value: Any, minimum: int = 1024, maximum: int = 10_000_000) -> Optional[int]:
|
||||
try:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
value = value.strip().replace(",", "")
|
||||
result = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if minimum <= result <= maximum:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _extract_first_int(payload: Dict[str, Any], keys: tuple[str, ...]) -> Optional[int]:
|
||||
keyset = {key.lower() for key in keys}
|
||||
for mapping in _iter_nested_dicts(payload):
|
||||
for key, value in mapping.items():
|
||||
if str(key).lower() not in keyset:
|
||||
continue
|
||||
coerced = _coerce_reasonable_int(value)
|
||||
if coerced is not None:
|
||||
return coerced
|
||||
return None
|
||||
|
||||
|
||||
def _extract_context_length(payload: Dict[str, Any]) -> Optional[int]:
|
||||
return _extract_first_int(payload, _CONTEXT_LENGTH_KEYS)
|
||||
|
||||
|
||||
def _extract_max_completion_tokens(payload: Dict[str, Any]) -> Optional[int]:
|
||||
return _extract_first_int(payload, _MAX_COMPLETION_KEYS)
|
||||
|
||||
|
||||
def _extract_pricing(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
alias_map = {
|
||||
"prompt": ("prompt", "input", "input_cost_per_token", "prompt_token_cost"),
|
||||
"completion": ("completion", "output", "output_cost_per_token", "completion_token_cost"),
|
||||
"request": ("request", "request_cost"),
|
||||
"cache_read": ("cache_read", "cached_prompt", "input_cache_read", "cache_read_cost_per_token"),
|
||||
"cache_write": ("cache_write", "cache_creation", "input_cache_write", "cache_write_cost_per_token"),
|
||||
}
|
||||
for mapping in _iter_nested_dicts(payload):
|
||||
normalized = {str(key).lower(): value for key, value in mapping.items()}
|
||||
if not any(any(alias in normalized for alias in aliases) for aliases in alias_map.values()):
|
||||
continue
|
||||
pricing: Dict[str, Any] = {}
|
||||
for target, aliases in alias_map.items():
|
||||
for alias in aliases:
|
||||
if alias in normalized and normalized[alias] not in (None, ""):
|
||||
pricing[target] = normalized[alias]
|
||||
break
|
||||
if pricing:
|
||||
return pricing
|
||||
return {}
|
||||
|
||||
|
||||
def _add_model_aliases(cache: Dict[str, Dict[str, Any]], model_id: str, entry: Dict[str, Any]) -> None:
|
||||
cache[model_id] = entry
|
||||
if "/" in model_id:
|
||||
bare_model = model_id.split("/", 1)[1]
|
||||
cache.setdefault(bare_model, entry)
|
||||
|
||||
|
||||
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
|
||||
"""Fetch model metadata from OpenRouter (cached for 1 hour)."""
|
||||
@@ -89,15 +375,16 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any
|
||||
cache = {}
|
||||
for model in data.get("data", []):
|
||||
model_id = model.get("id", "")
|
||||
cache[model_id] = {
|
||||
entry = {
|
||||
"context_length": model.get("context_length", 128000),
|
||||
"max_completion_tokens": model.get("top_provider", {}).get("max_completion_tokens", 4096),
|
||||
"name": model.get("name", model_id),
|
||||
"pricing": model.get("pricing", {}),
|
||||
}
|
||||
_add_model_aliases(cache, model_id, entry)
|
||||
canonical = model.get("canonical_slug", "")
|
||||
if canonical and canonical != model_id:
|
||||
cache[canonical] = cache[model_id]
|
||||
_add_model_aliases(cache, canonical, entry)
|
||||
|
||||
_model_metadata_cache = cache
|
||||
_model_metadata_cache_time = time.time()
|
||||
@@ -109,6 +396,97 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any
|
||||
return _model_metadata_cache or {}
|
||||
|
||||
|
||||
def fetch_endpoint_model_metadata(
|
||||
base_url: str,
|
||||
api_key: str = "",
|
||||
force_refresh: bool = False,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""Fetch model metadata from an OpenAI-compatible ``/models`` endpoint.
|
||||
|
||||
This is used for explicit custom endpoints where hardcoded global model-name
|
||||
defaults are unreliable. Results are cached in memory per base URL.
|
||||
"""
|
||||
normalized = _normalize_base_url(base_url)
|
||||
if not normalized or _is_openrouter_base_url(normalized):
|
||||
return {}
|
||||
|
||||
if not force_refresh:
|
||||
cached = _endpoint_model_metadata_cache.get(normalized)
|
||||
cached_at = _endpoint_model_metadata_cache_time.get(normalized, 0)
|
||||
if cached is not None and (time.time() - cached_at) < _ENDPOINT_MODEL_CACHE_TTL:
|
||||
return cached
|
||||
|
||||
candidates = [normalized]
|
||||
if normalized.endswith("/v1"):
|
||||
alternate = normalized[:-3].rstrip("/")
|
||||
else:
|
||||
alternate = normalized + "/v1"
|
||||
if alternate and alternate not in candidates:
|
||||
candidates.append(alternate)
|
||||
|
||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for candidate in candidates:
|
||||
url = candidate.rstrip("/") + "/models"
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
cache: Dict[str, Dict[str, Any]] = {}
|
||||
for model in payload.get("data", []):
|
||||
if not isinstance(model, dict):
|
||||
continue
|
||||
model_id = model.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
entry: Dict[str, Any] = {"name": model.get("name", model_id)}
|
||||
context_length = _extract_context_length(model)
|
||||
if context_length is not None:
|
||||
entry["context_length"] = context_length
|
||||
max_completion_tokens = _extract_max_completion_tokens(model)
|
||||
if max_completion_tokens is not None:
|
||||
entry["max_completion_tokens"] = max_completion_tokens
|
||||
pricing = _extract_pricing(model)
|
||||
if pricing:
|
||||
entry["pricing"] = pricing
|
||||
_add_model_aliases(cache, model_id, entry)
|
||||
|
||||
# If this is a llama.cpp server, query /props for actual allocated context
|
||||
is_llamacpp = any(
|
||||
m.get("owned_by") == "llamacpp"
|
||||
for m in payload.get("data", []) if isinstance(m, dict)
|
||||
)
|
||||
if is_llamacpp:
|
||||
try:
|
||||
# Try /v1/props first (current llama.cpp); fall back to /props for older builds
|
||||
base = candidate.rstrip("/").replace("/v1", "")
|
||||
props_resp = requests.get(base + "/v1/props", headers=headers, timeout=5)
|
||||
if not props_resp.ok:
|
||||
props_resp = requests.get(base + "/props", headers=headers, timeout=5)
|
||||
if props_resp.ok:
|
||||
props = props_resp.json()
|
||||
gen_settings = props.get("default_generation_settings", {})
|
||||
n_ctx = gen_settings.get("n_ctx")
|
||||
model_alias = props.get("model_alias", "")
|
||||
if n_ctx and model_alias and model_alias in cache:
|
||||
cache[model_alias]["context_length"] = n_ctx
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_endpoint_model_metadata_cache[normalized] = cache
|
||||
_endpoint_model_metadata_cache_time[normalized] = time.time()
|
||||
return cache
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
|
||||
if last_error:
|
||||
logger.debug("Failed to fetch model metadata from %s/models: %s", normalized, last_error)
|
||||
_endpoint_model_metadata_cache[normalized] = {}
|
||||
_endpoint_model_metadata_cache_time[normalized] = time.time()
|
||||
return {}
|
||||
|
||||
|
||||
def _get_context_cache_path() -> Path:
|
||||
"""Return path to the persistent context length cache file."""
|
||||
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
@@ -116,7 +494,7 @@ def _get_context_cache_path() -> Path:
|
||||
|
||||
|
||||
def _load_context_cache() -> Dict[str, int]:
|
||||
"""Load the model+provider → context_length cache from disk."""
|
||||
"""Load the model+provider -> context_length cache from disk."""
|
||||
path = _get_context_cache_path()
|
||||
if not path.exists():
|
||||
return {}
|
||||
@@ -145,7 +523,7 @@ def save_context_length(model: str, base_url: str, length: int) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
|
||||
logger.info("Cached context length %s → %s tokens", key, f"{length:,}")
|
||||
logger.info("Cached context length %s -> %s tokens", key, f"{length:,}")
|
||||
except Exception as e:
|
||||
logger.debug("Failed to save context length cache: %s", e)
|
||||
|
||||
@@ -193,33 +571,317 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
|
||||
return None
|
||||
|
||||
|
||||
def get_model_context_length(model: str, base_url: str = "") -> int:
|
||||
def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
|
||||
"""Return True if *candidate_id* (from server) matches *lookup_model* (configured).
|
||||
|
||||
Supports two forms:
|
||||
- Exact match: "nvidia-nemotron-super-49b-v1" == "nvidia-nemotron-super-49b-v1"
|
||||
- Slug match: "nvidia/nvidia-nemotron-super-49b-v1" matches "nvidia-nemotron-super-49b-v1"
|
||||
(the part after the last "/" equals lookup_model)
|
||||
|
||||
This covers LM Studio's native API which stores models as "publisher/slug"
|
||||
while users typically configure only the slug after the "local:" prefix.
|
||||
"""
|
||||
if candidate_id == lookup_model:
|
||||
return True
|
||||
# Slug match: basename of candidate equals the lookup name
|
||||
if "/" in candidate_id and candidate_id.rsplit("/", 1)[1] == lookup_model:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
|
||||
"""Query a local server for the model's context length."""
|
||||
import httpx
|
||||
|
||||
# Strip recognised provider prefix (e.g., "local:model-name" → "model-name").
|
||||
# Ollama "model:tag" colons (e.g. "qwen3.5:27b") are intentionally preserved.
|
||||
model = _strip_provider_prefix(model)
|
||||
|
||||
# Strip /v1 suffix to get the server root
|
||||
server_url = base_url.rstrip("/")
|
||||
if server_url.endswith("/v1"):
|
||||
server_url = server_url[:-3]
|
||||
|
||||
try:
|
||||
server_type = detect_local_server_type(base_url)
|
||||
except Exception:
|
||||
server_type = None
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=3.0) as client:
|
||||
# Ollama: /api/show returns model details with context info
|
||||
if server_type == "ollama":
|
||||
resp = client.post(f"{server_url}/api/show", json={"name": model})
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
# Check model_info for context length
|
||||
model_info = data.get("model_info", {})
|
||||
for key, value in model_info.items():
|
||||
if "context_length" in key and isinstance(value, (int, float)):
|
||||
return int(value)
|
||||
# Check parameters string for num_ctx
|
||||
params = data.get("parameters", "")
|
||||
if "num_ctx" in params:
|
||||
for line in params.split("\n"):
|
||||
if "num_ctx" in line:
|
||||
parts = line.strip().split()
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
return int(parts[-1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# LM Studio native API: /api/v1/models returns max_context_length.
|
||||
# This is more reliable than the OpenAI-compat /v1/models which
|
||||
# doesn't include context window information for LM Studio servers.
|
||||
# Use _model_id_matches for fuzzy matching: LM Studio stores models as
|
||||
# "publisher/slug" but users configure only "slug" after "local:" prefix.
|
||||
if server_type == "lm-studio":
|
||||
resp = client.get(f"{server_url}/api/v1/models")
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
for m in data.get("models", []):
|
||||
if _model_id_matches(m.get("key", ""), model) or _model_id_matches(m.get("id", ""), model):
|
||||
# Prefer loaded instance context (actual runtime value)
|
||||
for inst in m.get("loaded_instances", []):
|
||||
cfg = inst.get("config", {})
|
||||
ctx = cfg.get("context_length")
|
||||
if ctx and isinstance(ctx, (int, float)):
|
||||
return int(ctx)
|
||||
# Fall back to max_context_length (theoretical model max)
|
||||
ctx = m.get("max_context_length") or m.get("context_length")
|
||||
if ctx and isinstance(ctx, (int, float)):
|
||||
return int(ctx)
|
||||
|
||||
# LM Studio / vLLM / llama.cpp: try /v1/models/{model}
|
||||
resp = client.get(f"{server_url}/v1/models/{model}")
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
# vLLM returns max_model_len
|
||||
ctx = data.get("max_model_len") or data.get("context_length") or data.get("max_tokens")
|
||||
if ctx and isinstance(ctx, (int, float)):
|
||||
return int(ctx)
|
||||
|
||||
# Try /v1/models and find the model in the list.
|
||||
# Use _model_id_matches to handle "publisher/slug" vs bare "slug".
|
||||
resp = client.get(f"{server_url}/v1/models")
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
models_list = data.get("data", [])
|
||||
for m in models_list:
|
||||
if _model_id_matches(m.get("id", ""), model):
|
||||
ctx = m.get("max_model_len") or m.get("context_length") or m.get("max_tokens")
|
||||
if ctx and isinstance(ctx, (int, float)):
|
||||
return int(ctx)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_model_version(model: str) -> str:
|
||||
"""Normalize version separators for matching.
|
||||
|
||||
Nous uses dashes: claude-opus-4-6, claude-sonnet-4-5
|
||||
OpenRouter uses dots: claude-opus-4.6, claude-sonnet-4.5
|
||||
Normalize both to dashes for comparison.
|
||||
"""
|
||||
return model.replace(".", "-")
|
||||
|
||||
|
||||
def _query_anthropic_context_length(model: str, base_url: str, api_key: str) -> Optional[int]:
|
||||
"""Query Anthropic's /v1/models endpoint for context length.
|
||||
|
||||
Only works with regular ANTHROPIC_API_KEY (sk-ant-api*).
|
||||
OAuth tokens (sk-ant-oat*) from Claude Code return 401.
|
||||
"""
|
||||
if not api_key or api_key.startswith("sk-ant-oat"):
|
||||
return None # OAuth tokens can't access /v1/models
|
||||
try:
|
||||
base = base_url.rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
url = f"{base}/v1/models?limit=1000"
|
||||
headers = {
|
||||
"x-api-key": api_key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
resp = requests.get(url, headers=headers, timeout=10)
|
||||
if resp.status_code != 200:
|
||||
return None
|
||||
data = resp.json()
|
||||
for m in data.get("data", []):
|
||||
if m.get("id") == model:
|
||||
ctx = m.get("max_input_tokens")
|
||||
if isinstance(ctx, int) and ctx > 0:
|
||||
return ctx
|
||||
except Exception as e:
|
||||
logger.debug("Anthropic /v1/models query failed: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_nous_context_length(model: str) -> Optional[int]:
|
||||
"""Resolve Nous Portal model context length via OpenRouter metadata.
|
||||
|
||||
Nous model IDs are bare (e.g. 'claude-opus-4-6') while OpenRouter uses
|
||||
prefixed IDs (e.g. 'anthropic/claude-opus-4.6'). Try suffix matching
|
||||
with version normalization (dot↔dash).
|
||||
"""
|
||||
metadata = fetch_model_metadata() # OpenRouter cache
|
||||
# Exact match first
|
||||
if model in metadata:
|
||||
return metadata[model].get("context_length")
|
||||
|
||||
normalized = _normalize_model_version(model).lower()
|
||||
|
||||
for or_id, entry in metadata.items():
|
||||
bare = or_id.split("/", 1)[1] if "/" in or_id else or_id
|
||||
if bare.lower() == model.lower() or _normalize_model_version(bare).lower() == normalized:
|
||||
return entry.get("context_length")
|
||||
|
||||
# Partial prefix match for cases like gemini-3-flash → gemini-3-flash-preview
|
||||
# Require match to be at a word boundary (followed by -, :, or end of string)
|
||||
model_lower = model.lower()
|
||||
for or_id, entry in metadata.items():
|
||||
bare = or_id.split("/", 1)[1] if "/" in or_id else or_id
|
||||
for candidate, query in [(bare.lower(), model_lower), (_normalize_model_version(bare).lower(), normalized)]:
|
||||
if candidate.startswith(query) and (
|
||||
len(candidate) == len(query) or candidate[len(query)] in "-:."
|
||||
):
|
||||
return entry.get("context_length")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_model_context_length(
|
||||
model: str,
|
||||
base_url: str = "",
|
||||
api_key: str = "",
|
||||
config_context_length: int | None = None,
|
||||
provider: str = "",
|
||||
) -> int:
|
||||
"""Get the context length for a model.
|
||||
|
||||
Resolution order:
|
||||
0. Explicit config override (model.context_length or custom_providers per-model)
|
||||
1. Persistent cache (previously discovered via probing)
|
||||
2. OpenRouter API metadata
|
||||
3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match)
|
||||
4. First probe tier (2M) — will be narrowed on first context error
|
||||
2. Active endpoint metadata (/models for explicit custom endpoints)
|
||||
3. Local server query (for local endpoints)
|
||||
4. Anthropic /v1/models API (API-key users only, not OAuth)
|
||||
5. OpenRouter live API metadata
|
||||
6. Nous suffix-match via OpenRouter cache
|
||||
7. models.dev registry lookup (provider-aware)
|
||||
8. Thin hardcoded defaults (broad family patterns)
|
||||
9. Default fallback (128K)
|
||||
"""
|
||||
# 0. Explicit config override — user knows best
|
||||
if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
|
||||
return config_context_length
|
||||
|
||||
# Normalise provider-prefixed model names (e.g. "local:model-name" →
|
||||
# "model-name") so cache lookups and server queries use the bare ID that
|
||||
# local servers actually know about. Ollama "model:tag" colons are preserved.
|
||||
model = _strip_provider_prefix(model)
|
||||
|
||||
# 1. Check persistent cache (model+provider)
|
||||
if base_url:
|
||||
cached = get_cached_context_length(model, base_url)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
# 2. OpenRouter API metadata
|
||||
# 2. Active endpoint metadata for truly custom/unknown endpoints.
|
||||
# Known providers (Copilot, OpenAI, Anthropic, etc.) skip this — their
|
||||
# /models endpoint may report a provider-imposed limit (e.g. Copilot
|
||||
# returns 128k) instead of the model's full context (400k). models.dev
|
||||
# has the correct per-provider values and is checked at step 5+.
|
||||
if _is_custom_endpoint(base_url) and not _is_known_provider_base_url(base_url):
|
||||
endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
|
||||
matched = endpoint_metadata.get(model)
|
||||
if not matched:
|
||||
# Single-model servers: if only one model is loaded, use it
|
||||
if len(endpoint_metadata) == 1:
|
||||
matched = next(iter(endpoint_metadata.values()))
|
||||
else:
|
||||
# Fuzzy match: substring in either direction
|
||||
for key, entry in endpoint_metadata.items():
|
||||
if model in key or key in model:
|
||||
matched = entry
|
||||
break
|
||||
if matched:
|
||||
context_length = matched.get("context_length")
|
||||
if isinstance(context_length, int):
|
||||
return context_length
|
||||
if not _is_known_provider_base_url(base_url):
|
||||
# 3. Try querying local server directly
|
||||
if is_local_endpoint(base_url):
|
||||
local_ctx = _query_local_context_length(model, base_url)
|
||||
if local_ctx and local_ctx > 0:
|
||||
save_context_length(model, base_url, local_ctx)
|
||||
return local_ctx
|
||||
logger.info(
|
||||
"Could not detect context length for model %r at %s — "
|
||||
"defaulting to %s tokens (probe-down). Set model.context_length "
|
||||
"in config.yaml to override.",
|
||||
model, base_url, f"{DEFAULT_FALLBACK_CONTEXT:,}",
|
||||
)
|
||||
return DEFAULT_FALLBACK_CONTEXT
|
||||
|
||||
# 4. Anthropic /v1/models API (only for regular API keys, not OAuth)
|
||||
if provider == "anthropic" or (
|
||||
base_url and "api.anthropic.com" in base_url
|
||||
):
|
||||
ctx = _query_anthropic_context_length(model, base_url or "https://api.anthropic.com", api_key)
|
||||
if ctx:
|
||||
return ctx
|
||||
|
||||
# 5. Provider-aware lookups (before generic OpenRouter cache)
|
||||
# These are provider-specific and take priority over the generic OR cache,
|
||||
# since the same model can have different context limits per provider
|
||||
# (e.g. claude-opus-4.6 is 1M on Anthropic but 128K on GitHub Copilot).
|
||||
# If provider is generic (openrouter/custom/empty), try to infer from URL.
|
||||
effective_provider = provider
|
||||
if not effective_provider or effective_provider in ("openrouter", "custom"):
|
||||
if base_url:
|
||||
inferred = _infer_provider_from_url(base_url)
|
||||
if inferred:
|
||||
effective_provider = inferred
|
||||
|
||||
if effective_provider == "nous":
|
||||
ctx = _resolve_nous_context_length(model)
|
||||
if ctx:
|
||||
return ctx
|
||||
if effective_provider:
|
||||
from agent.models_dev import lookup_models_dev_context
|
||||
ctx = lookup_models_dev_context(effective_provider, model)
|
||||
if ctx:
|
||||
return ctx
|
||||
|
||||
# 6. OpenRouter live API metadata (provider-unaware fallback)
|
||||
metadata = fetch_model_metadata()
|
||||
if model in metadata:
|
||||
return metadata[model].get("context_length", 128000)
|
||||
|
||||
# 3. Hardcoded defaults (fuzzy match)
|
||||
for default_model, length in DEFAULT_CONTEXT_LENGTHS.items():
|
||||
if default_model in model or model in default_model:
|
||||
# 8. Hardcoded defaults (fuzzy match — longest key first for specificity)
|
||||
# Only check `default_model in model` (is the key a substring of the input).
|
||||
# The reverse (`model in default_model`) causes shorter names like
|
||||
# "claude-sonnet-4" to incorrectly match "claude-sonnet-4-6" and return 1M.
|
||||
model_lower = model.lower()
|
||||
for default_model, length in sorted(
|
||||
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
|
||||
):
|
||||
if default_model in model_lower:
|
||||
return length
|
||||
|
||||
# 4. Unknown model — start at highest probe tier
|
||||
return CONTEXT_PROBE_TIERS[0]
|
||||
# 9. Query local server as last resort
|
||||
if base_url and is_local_endpoint(base_url):
|
||||
local_ctx = _query_local_context_length(model, base_url)
|
||||
if local_ctx and local_ctx > 0:
|
||||
save_context_length(model, base_url, local_ctx)
|
||||
return local_ctx
|
||||
|
||||
# 10. Default fallback — 128K
|
||||
return DEFAULT_FALLBACK_CONTEXT
|
||||
|
||||
|
||||
def estimate_tokens_rough(text: str) -> int:
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
"""Models.dev registry integration for provider-aware context length detection.
|
||||
|
||||
Fetches model metadata from https://models.dev/api.json — a community-maintained
|
||||
database of 3800+ models across 100+ providers, including per-provider context
|
||||
windows, pricing, and capabilities.
|
||||
|
||||
Data is cached in memory (1hr TTL) and on disk (~/.hermes/models_dev_cache.json)
|
||||
to avoid cold-start network latency.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODELS_DEV_URL = "https://models.dev/api.json"
|
||||
_MODELS_DEV_CACHE_TTL = 3600 # 1 hour in-memory
|
||||
|
||||
# In-memory cache
|
||||
_models_dev_cache: Dict[str, Any] = {}
|
||||
_models_dev_cache_time: float = 0
|
||||
|
||||
# Provider ID mapping: Hermes provider names → models.dev provider IDs
|
||||
PROVIDER_TO_MODELS_DEV: Dict[str, str] = {
|
||||
"openrouter": "openrouter",
|
||||
"anthropic": "anthropic",
|
||||
"zai": "zai",
|
||||
"kimi-coding": "kimi-for-coding",
|
||||
"minimax": "minimax",
|
||||
"minimax-cn": "minimax-cn",
|
||||
"deepseek": "deepseek",
|
||||
"alibaba": "alibaba",
|
||||
"copilot": "github-copilot",
|
||||
"ai-gateway": "vercel",
|
||||
"opencode-zen": "opencode",
|
||||
"opencode-go": "opencode-go",
|
||||
"kilocode": "kilo",
|
||||
}
|
||||
|
||||
|
||||
def _get_cache_path() -> Path:
|
||||
"""Return path to disk cache file."""
|
||||
env_val = os.environ.get("HERMES_HOME", "")
|
||||
hermes_home = Path(env_val) if env_val else Path.home() / ".hermes"
|
||||
return hermes_home / "models_dev_cache.json"
|
||||
|
||||
|
||||
def _load_disk_cache() -> Dict[str, Any]:
|
||||
"""Load models.dev data from disk cache."""
|
||||
try:
|
||||
cache_path = _get_cache_path()
|
||||
if cache_path.exists():
|
||||
with open(cache_path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to load models.dev disk cache: %s", e)
|
||||
return {}
|
||||
|
||||
|
||||
def _save_disk_cache(data: Dict[str, Any]) -> None:
|
||||
"""Save models.dev data to disk cache."""
|
||||
try:
|
||||
cache_path = _get_cache_path()
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, separators=(",", ":"))
|
||||
except Exception as e:
|
||||
logger.debug("Failed to save models.dev disk cache: %s", e)
|
||||
|
||||
|
||||
def fetch_models_dev(force_refresh: bool = False) -> Dict[str, Any]:
|
||||
"""Fetch models.dev registry. In-memory cache (1hr) + disk fallback.
|
||||
|
||||
Returns the full registry dict keyed by provider ID, or empty dict on failure.
|
||||
"""
|
||||
global _models_dev_cache, _models_dev_cache_time
|
||||
|
||||
# Check in-memory cache
|
||||
if (
|
||||
not force_refresh
|
||||
and _models_dev_cache
|
||||
and (time.time() - _models_dev_cache_time) < _MODELS_DEV_CACHE_TTL
|
||||
):
|
||||
return _models_dev_cache
|
||||
|
||||
# Try network fetch
|
||||
try:
|
||||
response = requests.get(MODELS_DEV_URL, timeout=15)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if isinstance(data, dict) and len(data) > 0:
|
||||
_models_dev_cache = data
|
||||
_models_dev_cache_time = time.time()
|
||||
_save_disk_cache(data)
|
||||
logger.debug(
|
||||
"Fetched models.dev registry: %d providers, %d total models",
|
||||
len(data),
|
||||
sum(len(p.get("models", {})) for p in data.values() if isinstance(p, dict)),
|
||||
)
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.debug("Failed to fetch models.dev: %s", e)
|
||||
|
||||
# Fall back to disk cache — use a short TTL (5 min) so we retry
|
||||
# the network fetch soon instead of serving stale data for a full hour.
|
||||
if not _models_dev_cache:
|
||||
_models_dev_cache = _load_disk_cache()
|
||||
if _models_dev_cache:
|
||||
_models_dev_cache_time = time.time() - _MODELS_DEV_CACHE_TTL + 300
|
||||
logger.debug("Loaded models.dev from disk cache (%d providers)", len(_models_dev_cache))
|
||||
|
||||
return _models_dev_cache
|
||||
|
||||
|
||||
def lookup_models_dev_context(provider: str, model: str) -> Optional[int]:
|
||||
"""Look up context_length for a provider+model combo in models.dev.
|
||||
|
||||
Returns the context window in tokens, or None if not found.
|
||||
Handles case-insensitive matching and filters out context=0 entries.
|
||||
"""
|
||||
mdev_provider_id = PROVIDER_TO_MODELS_DEV.get(provider)
|
||||
if not mdev_provider_id:
|
||||
return None
|
||||
|
||||
data = fetch_models_dev()
|
||||
provider_data = data.get(mdev_provider_id)
|
||||
if not isinstance(provider_data, dict):
|
||||
return None
|
||||
|
||||
models = provider_data.get("models", {})
|
||||
if not isinstance(models, dict):
|
||||
return None
|
||||
|
||||
# Exact match
|
||||
entry = models.get(model)
|
||||
if entry:
|
||||
ctx = _extract_context(entry)
|
||||
if ctx:
|
||||
return ctx
|
||||
|
||||
# Case-insensitive match
|
||||
model_lower = model.lower()
|
||||
for mid, mdata in models.items():
|
||||
if mid.lower() == model_lower:
|
||||
ctx = _extract_context(mdata)
|
||||
if ctx:
|
||||
return ctx
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_context(entry: Dict[str, Any]) -> Optional[int]:
|
||||
"""Extract context_length from a models.dev model entry.
|
||||
|
||||
Returns None for invalid/zero values (some audio/image models have context=0).
|
||||
"""
|
||||
if not isinstance(entry, dict):
|
||||
return None
|
||||
limit = entry.get("limit")
|
||||
if not isinstance(limit, dict):
|
||||
return None
|
||||
ctx = limit.get("context")
|
||||
if isinstance(ctx, (int, float)) and ctx > 0:
|
||||
return int(ctx)
|
||||
return None
|
||||
+225
-67
@@ -56,6 +56,61 @@ def _scan_context_content(content: str, filename: str) -> str:
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def _find_git_root(start: Path) -> Optional[Path]:
|
||||
"""Walk *start* and its parents looking for a ``.git`` directory.
|
||||
|
||||
Returns the directory containing ``.git``, or ``None`` if we hit the
|
||||
filesystem root without finding one.
|
||||
"""
|
||||
current = start.resolve()
|
||||
for parent in [current, *current.parents]:
|
||||
if (parent / ".git").exists():
|
||||
return parent
|
||||
return None
|
||||
|
||||
|
||||
_HERMES_MD_NAMES = (".hermes.md", "HERMES.md")
|
||||
|
||||
|
||||
def _find_hermes_md(cwd: Path) -> Optional[Path]:
|
||||
"""Discover the nearest ``.hermes.md`` or ``HERMES.md``.
|
||||
|
||||
Search order: *cwd* first, then each parent directory up to (and
|
||||
including) the git repository root. Returns the first match, or
|
||||
``None`` if nothing is found.
|
||||
"""
|
||||
stop_at = _find_git_root(cwd)
|
||||
current = cwd.resolve()
|
||||
|
||||
for directory in [current, *current.parents]:
|
||||
for name in _HERMES_MD_NAMES:
|
||||
candidate = directory / name
|
||||
if candidate.is_file():
|
||||
return candidate
|
||||
# Stop walking at the git root (or filesystem root).
|
||||
if stop_at and directory == stop_at:
|
||||
break
|
||||
return None
|
||||
|
||||
|
||||
def _strip_yaml_frontmatter(content: str) -> str:
|
||||
"""Remove optional YAML frontmatter (``---`` delimited) from *content*.
|
||||
|
||||
The frontmatter may contain structured config (model overrides, tool
|
||||
settings) that will be handled separately in a future PR. For now we
|
||||
strip it so only the human-readable markdown body is injected into the
|
||||
system prompt.
|
||||
"""
|
||||
if content.startswith("---"):
|
||||
end = content.find("\n---", 3)
|
||||
if end != -1:
|
||||
# Skip past the closing --- and any trailing newline
|
||||
body = content[end + 4:].lstrip("\n")
|
||||
return body if body else content
|
||||
return content
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Constants
|
||||
# =========================================================================
|
||||
@@ -73,9 +128,15 @@ DEFAULT_AGENT_IDENTITY = (
|
||||
MEMORY_GUIDANCE = (
|
||||
"You have persistent memory across sessions. Save durable facts using the memory "
|
||||
"tool: user preferences, environment details, tool quirks, and stable conventions. "
|
||||
"Memory is injected into every turn, so keep it compact. Do NOT save task progress, "
|
||||
"session outcomes, or completed-work logs to memory; use session_search to recall "
|
||||
"those from past transcripts."
|
||||
"Memory is injected into every turn, so keep it compact and focused on facts that "
|
||||
"will still matter later.\n"
|
||||
"Prioritize what reduces future user steering — the most valuable memory is one "
|
||||
"that prevents the user from having to correct or remind you again. "
|
||||
"User preferences and recurring corrections matter more than procedural task details.\n"
|
||||
"Do NOT save task progress, session outcomes, completed-work logs, or temporary TODO "
|
||||
"state to memory; use session_search to recall those from past transcripts. "
|
||||
"If you've discovered a new way to do something, solved a problem that could be "
|
||||
"necessary later, save it as a skill with the skill tool."
|
||||
)
|
||||
|
||||
SESSION_SEARCH_GUIDANCE = (
|
||||
@@ -86,8 +147,11 @@ SESSION_SEARCH_GUIDANCE = (
|
||||
|
||||
SKILLS_GUIDANCE = (
|
||||
"After completing a complex task (5+ tool calls), fixing a tricky error, "
|
||||
"or discovering a non-trivial workflow, consider saving the approach as a "
|
||||
"skill with skill_manage so you can reuse it next time."
|
||||
"or discovering a non-trivial workflow, save the approach as a "
|
||||
"skill with skill_manage so you can reuse it next time.\n"
|
||||
"When using a skill and finding it outdated, incomplete, or wrong, "
|
||||
"patch it immediately with skill_manage(action='patch') — don't wait to be asked. "
|
||||
"Skills that aren't maintained become liabilities."
|
||||
)
|
||||
|
||||
PLATFORM_HINTS = {
|
||||
@@ -142,16 +206,21 @@ PLATFORM_HINTS = {
|
||||
"contextually appropriate."
|
||||
),
|
||||
"cron": (
|
||||
"You are running as a scheduled cron job. Your final response is automatically "
|
||||
"delivered to the job's configured destination, so do not use send_message to "
|
||||
"send to that same target again. If you want the user to receive something in "
|
||||
"the scheduled destination, put it directly in your final response. Use "
|
||||
"send_message only for additional or different targets."
|
||||
"You are running as a scheduled cron job. There is no user present — you "
|
||||
"cannot ask questions, request clarification, or wait for follow-up. Execute "
|
||||
"the task fully and autonomously, making reasonable decisions where needed. "
|
||||
"Your final response is automatically delivered to the job's configured "
|
||||
"destination — put the primary content directly in your response."
|
||||
),
|
||||
"cli": (
|
||||
"You are a CLI AI Agent. Try not to use markdown but simple text "
|
||||
"renderable inside a terminal."
|
||||
),
|
||||
"sms": (
|
||||
"You are communicating via SMS. Keep responses concise and use plain text "
|
||||
"only — no markdown, no formatting. SMS messages are limited to ~1600 "
|
||||
"characters, so be brief and direct."
|
||||
),
|
||||
}
|
||||
|
||||
CONTEXT_FILE_MAX_CHARS = 20_000
|
||||
@@ -261,28 +330,34 @@ def build_skills_system_prompt(
|
||||
# Each entry: (skill_name, description)
|
||||
# Supports sub-categories: skills/mlops/training/axolotl/SKILL.md
|
||||
# -> category "mlops/training", skill "axolotl"
|
||||
# Load disabled skill names once for the entire scan
|
||||
try:
|
||||
from tools.skills_tool import _get_disabled_skill_names
|
||||
disabled = _get_disabled_skill_names()
|
||||
except Exception:
|
||||
disabled = set()
|
||||
|
||||
skills_by_category: dict[str, list[tuple[str, str]]] = {}
|
||||
for skill_file in skills_dir.rglob("SKILL.md"):
|
||||
is_compatible, _, desc = _parse_skill_file(skill_file)
|
||||
is_compatible, frontmatter, desc = _parse_skill_file(skill_file)
|
||||
if not is_compatible:
|
||||
continue
|
||||
# Skip skills whose conditional activation rules exclude them
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
if not _skill_should_show(conditions, available_tools, available_toolsets):
|
||||
continue
|
||||
rel_path = skill_file.relative_to(skills_dir)
|
||||
parts = rel_path.parts
|
||||
if len(parts) >= 2:
|
||||
# Category is everything between skills_dir and the skill folder
|
||||
# e.g. parts = ("mlops", "training", "axolotl", "SKILL.md")
|
||||
# → category = "mlops/training", skill_name = "axolotl"
|
||||
# e.g. parts = ("github", "github-auth", "SKILL.md")
|
||||
# → category = "github", skill_name = "github-auth"
|
||||
skill_name = parts[-2]
|
||||
category = "/".join(parts[:-2]) if len(parts) > 2 else parts[0]
|
||||
else:
|
||||
category = "general"
|
||||
skill_name = skill_file.parent.name
|
||||
# Respect user's disabled skills config
|
||||
fm_name = frontmatter.get("name", skill_name)
|
||||
if fm_name in disabled or skill_name in disabled:
|
||||
continue
|
||||
# Skip skills whose conditional activation rules exclude them
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
if not _skill_should_show(conditions, available_tools, available_toolsets):
|
||||
continue
|
||||
skills_by_category.setdefault(category, []).append((skill_name, desc))
|
||||
|
||||
if not skills_by_category:
|
||||
@@ -326,6 +401,9 @@ def build_skills_system_prompt(
|
||||
"Before replying, scan the skills below. If one clearly matches your task, "
|
||||
"load it with skill_view(name) and follow its instructions. "
|
||||
"If a skill has issues, fix it with skill_manage(action='patch').\n"
|
||||
"After difficult/iterative tasks, offer to save as a skill. "
|
||||
"If a skill you loaded was missing steps, had wrong commands, or needed "
|
||||
"pitfalls you discovered, update it before finishing.\n"
|
||||
"\n"
|
||||
"<available_skills>\n"
|
||||
+ "\n".join(index_lines) + "\n"
|
||||
@@ -351,19 +429,59 @@ def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE
|
||||
return head + marker + tail
|
||||
|
||||
|
||||
def build_context_files_prompt(cwd: Optional[str] = None) -> str:
|
||||
"""Discover and load context files for the system prompt.
|
||||
def load_soul_md() -> Optional[str]:
|
||||
"""Load SOUL.md from HERMES_HOME and return its content, or None.
|
||||
|
||||
Discovery: AGENTS.md (recursive), .cursorrules / .cursor/rules/*.mdc,
|
||||
and SOUL.md from HERMES_HOME only. Each capped at 20,000 chars.
|
||||
Used as the agent identity (slot #1 in the system prompt). When this
|
||||
returns content, ``build_context_files_prompt`` should be called with
|
||||
``skip_soul=True`` so SOUL.md isn't injected twice.
|
||||
"""
|
||||
if cwd is None:
|
||||
cwd = os.getcwd()
|
||||
try:
|
||||
from hermes_cli.config import ensure_hermes_home
|
||||
ensure_hermes_home()
|
||||
except Exception as e:
|
||||
logger.debug("Could not ensure HERMES_HOME before loading SOUL.md: %s", e)
|
||||
|
||||
cwd_path = Path(cwd).resolve()
|
||||
sections = []
|
||||
soul_path = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "SOUL.md"
|
||||
if not soul_path.exists():
|
||||
return None
|
||||
try:
|
||||
content = soul_path.read_text(encoding="utf-8").strip()
|
||||
if not content:
|
||||
return None
|
||||
content = _scan_context_content(content, "SOUL.md")
|
||||
content = _truncate_content(content, "SOUL.md")
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.debug("Could not read SOUL.md from %s: %s", soul_path, e)
|
||||
return None
|
||||
|
||||
# AGENTS.md (hierarchical, recursive)
|
||||
|
||||
def _load_hermes_md(cwd_path: Path) -> str:
|
||||
""".hermes.md / HERMES.md — walk to git root."""
|
||||
hermes_md_path = _find_hermes_md(cwd_path)
|
||||
if not hermes_md_path:
|
||||
return ""
|
||||
try:
|
||||
content = hermes_md_path.read_text(encoding="utf-8").strip()
|
||||
if not content:
|
||||
return ""
|
||||
content = _strip_yaml_frontmatter(content)
|
||||
rel = hermes_md_path.name
|
||||
try:
|
||||
rel = str(hermes_md_path.relative_to(cwd_path))
|
||||
except ValueError:
|
||||
pass
|
||||
content = _scan_context_content(content, rel)
|
||||
result = f"## {rel}\n\n{content}"
|
||||
return _truncate_content(result, ".hermes.md")
|
||||
except Exception as e:
|
||||
logger.debug("Could not read %s: %s", hermes_md_path, e)
|
||||
return ""
|
||||
|
||||
|
||||
def _load_agents_md(cwd_path: Path) -> str:
|
||||
"""AGENTS.md — hierarchical, recursive directory walk."""
|
||||
top_level_agents = None
|
||||
for name in ["AGENTS.md", "agents.md"]:
|
||||
candidate = cwd_path / name
|
||||
@@ -371,31 +489,51 @@ def build_context_files_prompt(cwd: Optional[str] = None) -> str:
|
||||
top_level_agents = candidate
|
||||
break
|
||||
|
||||
if top_level_agents:
|
||||
agents_files = []
|
||||
for root, dirs, files in os.walk(cwd_path):
|
||||
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('node_modules', '__pycache__', 'venv', '.venv')]
|
||||
for f in files:
|
||||
if f.lower() == "agents.md":
|
||||
agents_files.append(Path(root) / f)
|
||||
agents_files.sort(key=lambda p: len(p.parts))
|
||||
if not top_level_agents:
|
||||
return ""
|
||||
|
||||
total_agents_content = ""
|
||||
for agents_path in agents_files:
|
||||
agents_files = []
|
||||
for root, dirs, files in os.walk(cwd_path):
|
||||
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('node_modules', '__pycache__', 'venv', '.venv')]
|
||||
for f in files:
|
||||
if f.lower() == "agents.md":
|
||||
agents_files.append(Path(root) / f)
|
||||
agents_files.sort(key=lambda p: len(p.parts))
|
||||
|
||||
total_content = ""
|
||||
for agents_path in agents_files:
|
||||
try:
|
||||
content = agents_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
rel_path = agents_path.relative_to(cwd_path)
|
||||
content = _scan_context_content(content, str(rel_path))
|
||||
total_content += f"## {rel_path}\n\n{content}\n\n"
|
||||
except Exception as e:
|
||||
logger.debug("Could not read %s: %s", agents_path, e)
|
||||
|
||||
if not total_content:
|
||||
return ""
|
||||
return _truncate_content(total_content, "AGENTS.md")
|
||||
|
||||
|
||||
def _load_claude_md(cwd_path: Path) -> str:
|
||||
"""CLAUDE.md / claude.md — cwd only."""
|
||||
for name in ["CLAUDE.md", "claude.md"]:
|
||||
candidate = cwd_path / name
|
||||
if candidate.exists():
|
||||
try:
|
||||
content = agents_path.read_text(encoding="utf-8").strip()
|
||||
content = candidate.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
rel_path = agents_path.relative_to(cwd_path)
|
||||
content = _scan_context_content(content, str(rel_path))
|
||||
total_agents_content += f"## {rel_path}\n\n{content}\n\n"
|
||||
content = _scan_context_content(content, name)
|
||||
result = f"## {name}\n\n{content}"
|
||||
return _truncate_content(result, "CLAUDE.md")
|
||||
except Exception as e:
|
||||
logger.debug("Could not read %s: %s", agents_path, e)
|
||||
logger.debug("Could not read %s: %s", candidate, e)
|
||||
return ""
|
||||
|
||||
if total_agents_content:
|
||||
total_agents_content = _truncate_content(total_agents_content, "AGENTS.md")
|
||||
sections.append(total_agents_content)
|
||||
|
||||
# .cursorrules
|
||||
def _load_cursorrules(cwd_path: Path) -> str:
|
||||
""".cursorrules + .cursor/rules/*.mdc — cwd only."""
|
||||
cursorrules_content = ""
|
||||
cursorrules_file = cwd_path / ".cursorrules"
|
||||
if cursorrules_file.exists():
|
||||
@@ -419,27 +557,47 @@ def build_context_files_prompt(cwd: Optional[str] = None) -> str:
|
||||
except Exception as e:
|
||||
logger.debug("Could not read %s: %s", mdc_file, e)
|
||||
|
||||
if cursorrules_content:
|
||||
cursorrules_content = _truncate_content(cursorrules_content, ".cursorrules")
|
||||
sections.append(cursorrules_content)
|
||||
if not cursorrules_content:
|
||||
return ""
|
||||
return _truncate_content(cursorrules_content, ".cursorrules")
|
||||
|
||||
# SOUL.md from HERMES_HOME only
|
||||
try:
|
||||
from hermes_cli.config import ensure_hermes_home
|
||||
ensure_hermes_home()
|
||||
except Exception as e:
|
||||
logger.debug("Could not ensure HERMES_HOME before loading SOUL.md: %s", e)
|
||||
|
||||
soul_path = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "SOUL.md"
|
||||
if soul_path.exists():
|
||||
try:
|
||||
content = soul_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
content = _scan_context_content(content, "SOUL.md")
|
||||
content = _truncate_content(content, "SOUL.md")
|
||||
sections.append(content)
|
||||
except Exception as e:
|
||||
logger.debug("Could not read SOUL.md from %s: %s", soul_path, e)
|
||||
def build_context_files_prompt(cwd: Optional[str] = None, skip_soul: bool = False) -> str:
|
||||
"""Discover and load context files for the system prompt.
|
||||
|
||||
Priority (first found wins — only ONE project context type is loaded):
|
||||
1. .hermes.md / HERMES.md (walk to git root)
|
||||
2. AGENTS.md / agents.md (recursive directory walk)
|
||||
3. CLAUDE.md / claude.md (cwd only)
|
||||
4. .cursorrules / .cursor/rules/*.mdc (cwd only)
|
||||
|
||||
SOUL.md from HERMES_HOME is independent and always included when present.
|
||||
Each context source is capped at 20,000 chars.
|
||||
|
||||
When *skip_soul* is True, SOUL.md is not included here (it was already
|
||||
loaded via ``load_soul_md()`` for the identity slot).
|
||||
"""
|
||||
if cwd is None:
|
||||
cwd = os.getcwd()
|
||||
|
||||
cwd_path = Path(cwd).resolve()
|
||||
sections = []
|
||||
|
||||
# Priority-based project context: first match wins
|
||||
project_context = (
|
||||
_load_hermes_md(cwd_path)
|
||||
or _load_agents_md(cwd_path)
|
||||
or _load_claude_md(cwd_path)
|
||||
or _load_cursorrules(cwd_path)
|
||||
)
|
||||
if project_context:
|
||||
sections.append(project_context)
|
||||
|
||||
# SOUL.md from HERMES_HOME only — skip when already loaded as identity
|
||||
if not skip_soul:
|
||||
soul_content = load_soul_md()
|
||||
if soul_content:
|
||||
sections.append(soul_content)
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
@@ -12,13 +12,14 @@ import copy
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict, native_anthropic: bool = False) -> None:
|
||||
"""Add cache_control to a single message, handling all format variations."""
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "tool":
|
||||
msg["cache_control"] = cache_marker
|
||||
if native_anthropic:
|
||||
msg["cache_control"] = cache_marker
|
||||
return
|
||||
|
||||
if content is None or content == "":
|
||||
@@ -40,6 +41,7 @@ def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
def apply_anthropic_cache_control(
|
||||
api_messages: List[Dict[str, Any]],
|
||||
cache_ttl: str = "5m",
|
||||
native_anthropic: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Apply system_and_3 caching strategy to messages for Anthropic models.
|
||||
|
||||
@@ -59,12 +61,12 @@ def apply_anthropic_cache_control(
|
||||
breakpoints_used = 0
|
||||
|
||||
if messages[0].get("role") == "system":
|
||||
_apply_cache_marker(messages[0], marker)
|
||||
_apply_cache_marker(messages[0], marker, native_anthropic=native_anthropic)
|
||||
breakpoints_used += 1
|
||||
|
||||
remaining = 4 - breakpoints_used
|
||||
non_sys = [i for i in range(len(messages)) if messages[i].get("role") != "system"]
|
||||
for idx in non_sys[-remaining:]:
|
||||
_apply_cache_marker(messages[idx], marker)
|
||||
_apply_cache_marker(messages[idx], marker, native_anthropic=native_anthropic)
|
||||
|
||||
return messages
|
||||
|
||||
@@ -100,6 +100,10 @@ def redact_sensitive_text(text: str) -> str:
|
||||
Safe to call on any string -- non-matching text passes through unchanged.
|
||||
Disabled when security.redact_secrets is false in config.yaml.
|
||||
"""
|
||||
if text is None:
|
||||
return None
|
||||
if not isinstance(text, str):
|
||||
text = str(text)
|
||||
if not text:
|
||||
return text
|
||||
if os.getenv("HERMES_REDACT_SECRETS", "").lower() in ("0", "false", "no", "off"):
|
||||
|
||||
@@ -157,9 +157,10 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
global _skill_commands
|
||||
_skill_commands = {}
|
||||
try:
|
||||
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform
|
||||
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform, _get_disabled_skill_names
|
||||
if not SKILLS_DIR.exists():
|
||||
return _skill_commands
|
||||
disabled = _get_disabled_skill_names()
|
||||
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
if any(part in ('.git', '.github', '.hub') for part in skill_md.parts):
|
||||
continue
|
||||
@@ -170,6 +171,9 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
if not skill_matches_platform(frontmatter):
|
||||
continue
|
||||
name = frontmatter.get('name', skill_md.parent.name)
|
||||
# Respect user's disabled skills config
|
||||
if name in disabled:
|
||||
continue
|
||||
description = frontmatter.get('description', '')
|
||||
if not description:
|
||||
for line in body.strip().split('\n'):
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
"""Helpers for optional cheap-vs-strong model routing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
_COMPLEX_KEYWORDS = {
|
||||
"debug",
|
||||
"debugging",
|
||||
"implement",
|
||||
"implementation",
|
||||
"refactor",
|
||||
"patch",
|
||||
"traceback",
|
||||
"stacktrace",
|
||||
"exception",
|
||||
"error",
|
||||
"analyze",
|
||||
"analysis",
|
||||
"investigate",
|
||||
"architecture",
|
||||
"design",
|
||||
"compare",
|
||||
"benchmark",
|
||||
"optimize",
|
||||
"optimise",
|
||||
"review",
|
||||
"terminal",
|
||||
"shell",
|
||||
"tool",
|
||||
"tools",
|
||||
"pytest",
|
||||
"test",
|
||||
"tests",
|
||||
"plan",
|
||||
"planning",
|
||||
"delegate",
|
||||
"subagent",
|
||||
"cron",
|
||||
"docker",
|
||||
"kubernetes",
|
||||
}
|
||||
|
||||
_URL_RE = re.compile(r"https?://|www\.", re.IGNORECASE)
|
||||
|
||||
|
||||
def _coerce_bool(value: Any, default: bool = False) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
return bool(value)
|
||||
|
||||
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def choose_cheap_model_route(user_message: str, routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
"""Return the configured cheap-model route when a message looks simple.
|
||||
|
||||
Conservative by design: if the message has signs of code/tool/debugging/
|
||||
long-form work, keep the primary model.
|
||||
"""
|
||||
cfg = routing_config or {}
|
||||
if not _coerce_bool(cfg.get("enabled"), False):
|
||||
return None
|
||||
|
||||
cheap_model = cfg.get("cheap_model") or {}
|
||||
if not isinstance(cheap_model, dict):
|
||||
return None
|
||||
provider = str(cheap_model.get("provider") or "").strip().lower()
|
||||
model = str(cheap_model.get("model") or "").strip()
|
||||
if not provider or not model:
|
||||
return None
|
||||
|
||||
text = (user_message or "").strip()
|
||||
if not text:
|
||||
return None
|
||||
|
||||
max_chars = _coerce_int(cfg.get("max_simple_chars"), 160)
|
||||
max_words = _coerce_int(cfg.get("max_simple_words"), 28)
|
||||
|
||||
if len(text) > max_chars:
|
||||
return None
|
||||
if len(text.split()) > max_words:
|
||||
return None
|
||||
if text.count("\n") > 1:
|
||||
return None
|
||||
if "```" in text or "`" in text:
|
||||
return None
|
||||
if _URL_RE.search(text):
|
||||
return None
|
||||
|
||||
lowered = text.lower()
|
||||
words = {token.strip(".,:;!?()[]{}\"'`") for token in lowered.split()}
|
||||
if words & _COMPLEX_KEYWORDS:
|
||||
return None
|
||||
|
||||
route = dict(cheap_model)
|
||||
route["provider"] = provider
|
||||
route["model"] = model
|
||||
route["routing_reason"] = "simple_turn"
|
||||
return route
|
||||
|
||||
|
||||
def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any]], primary: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Resolve the effective model/runtime for one turn.
|
||||
|
||||
Returns a dict with model/runtime/signature/label fields.
|
||||
"""
|
||||
route = choose_cheap_model_route(user_message, routing_config)
|
||||
if not route:
|
||||
return {
|
||||
"model": primary.get("model"),
|
||||
"runtime": {
|
||||
"api_key": primary.get("api_key"),
|
||||
"base_url": primary.get("base_url"),
|
||||
"provider": primary.get("provider"),
|
||||
"api_mode": primary.get("api_mode"),
|
||||
"command": primary.get("command"),
|
||||
"args": list(primary.get("args") or []),
|
||||
},
|
||||
"label": None,
|
||||
"signature": (
|
||||
primary.get("model"),
|
||||
primary.get("provider"),
|
||||
primary.get("base_url"),
|
||||
primary.get("api_mode"),
|
||||
primary.get("command"),
|
||||
tuple(primary.get("args") or ()),
|
||||
),
|
||||
}
|
||||
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
|
||||
explicit_api_key = None
|
||||
api_key_env = str(route.get("api_key_env") or "").strip()
|
||||
if api_key_env:
|
||||
explicit_api_key = os.getenv(api_key_env) or None
|
||||
|
||||
try:
|
||||
runtime = resolve_runtime_provider(
|
||||
requested=route.get("provider"),
|
||||
explicit_api_key=explicit_api_key,
|
||||
explicit_base_url=route.get("base_url"),
|
||||
)
|
||||
except Exception:
|
||||
return {
|
||||
"model": primary.get("model"),
|
||||
"runtime": {
|
||||
"api_key": primary.get("api_key"),
|
||||
"base_url": primary.get("base_url"),
|
||||
"provider": primary.get("provider"),
|
||||
"api_mode": primary.get("api_mode"),
|
||||
"command": primary.get("command"),
|
||||
"args": list(primary.get("args") or []),
|
||||
},
|
||||
"label": None,
|
||||
"signature": (
|
||||
primary.get("model"),
|
||||
primary.get("provider"),
|
||||
primary.get("base_url"),
|
||||
primary.get("api_mode"),
|
||||
primary.get("command"),
|
||||
tuple(primary.get("args") or ()),
|
||||
),
|
||||
}
|
||||
|
||||
return {
|
||||
"model": route.get("model"),
|
||||
"runtime": {
|
||||
"api_key": runtime.get("api_key"),
|
||||
"base_url": runtime.get("base_url"),
|
||||
"provider": runtime.get("provider"),
|
||||
"api_mode": runtime.get("api_mode"),
|
||||
"command": runtime.get("command"),
|
||||
"args": list(runtime.get("args") or []),
|
||||
},
|
||||
"label": f"smart route → {route.get('model')} ({runtime.get('provider')})",
|
||||
"signature": (
|
||||
route.get("model"),
|
||||
runtime.get("provider"),
|
||||
runtime.get("base_url"),
|
||||
runtime.get("api_mode"),
|
||||
runtime.get("command"),
|
||||
tuple(runtime.get("args") or ()),
|
||||
),
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Auto-generate short session titles from the first user/assistant exchange.
|
||||
|
||||
Runs asynchronously after the first response is delivered so it never
|
||||
adds latency to the user-facing reply.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from agent.auxiliary_client import call_llm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TITLE_PROMPT = (
|
||||
"Generate a short, descriptive title (3-7 words) for a conversation that starts with the "
|
||||
"following exchange. The title should capture the main topic or intent. "
|
||||
"Return ONLY the title text, nothing else. No quotes, no punctuation at the end, no prefixes."
|
||||
)
|
||||
|
||||
|
||||
def generate_title(user_message: str, assistant_response: str, timeout: float = 15.0) -> Optional[str]:
|
||||
"""Generate a session title from the first exchange.
|
||||
|
||||
Uses the auxiliary LLM client (cheapest/fastest available model).
|
||||
Returns the title string or None on failure.
|
||||
"""
|
||||
# Truncate long messages to keep the request small
|
||||
user_snippet = user_message[:500] if user_message else ""
|
||||
assistant_snippet = assistant_response[:500] if assistant_response else ""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": _TITLE_PROMPT},
|
||||
{"role": "user", "content": f"User: {user_snippet}\n\nAssistant: {assistant_snippet}"},
|
||||
]
|
||||
|
||||
try:
|
||||
response = call_llm(
|
||||
task="compression", # reuse compression task config (cheap/fast model)
|
||||
messages=messages,
|
||||
max_tokens=30,
|
||||
temperature=0.3,
|
||||
timeout=timeout,
|
||||
)
|
||||
title = (response.choices[0].message.content or "").strip()
|
||||
# Clean up: remove quotes, trailing punctuation, prefixes like "Title: "
|
||||
title = title.strip('"\'')
|
||||
if title.lower().startswith("title:"):
|
||||
title = title[6:].strip()
|
||||
# Enforce reasonable length
|
||||
if len(title) > 80:
|
||||
title = title[:77] + "..."
|
||||
return title if title else None
|
||||
except Exception as e:
|
||||
logger.debug("Title generation failed: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def auto_title_session(
|
||||
session_db,
|
||||
session_id: str,
|
||||
user_message: str,
|
||||
assistant_response: str,
|
||||
) -> None:
|
||||
"""Generate and set a session title if one doesn't already exist.
|
||||
|
||||
Called in a background thread after the first exchange completes.
|
||||
Silently skips if:
|
||||
- session_db is None
|
||||
- session already has a title (user-set or previously auto-generated)
|
||||
- title generation fails
|
||||
"""
|
||||
if not session_db or not session_id:
|
||||
return
|
||||
|
||||
# Check if title already exists (user may have set one via /title before first response)
|
||||
try:
|
||||
existing = session_db.get_session_title(session_id)
|
||||
if existing:
|
||||
return
|
||||
except Exception:
|
||||
return
|
||||
|
||||
title = generate_title(user_message, assistant_response)
|
||||
if not title:
|
||||
return
|
||||
|
||||
try:
|
||||
session_db.set_session_title(session_id, title)
|
||||
logger.debug("Auto-generated session title: %s", title)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to set auto-generated title: %s", e)
|
||||
|
||||
|
||||
def maybe_auto_title(
|
||||
session_db,
|
||||
session_id: str,
|
||||
user_message: str,
|
||||
assistant_response: str,
|
||||
conversation_history: list,
|
||||
) -> None:
|
||||
"""Fire-and-forget title generation after the first exchange.
|
||||
|
||||
Only generates a title when:
|
||||
- This appears to be the first user→assistant exchange
|
||||
- No title is already set
|
||||
"""
|
||||
if not session_db or not session_id or not user_message or not assistant_response:
|
||||
return
|
||||
|
||||
# Count user messages in history to detect first exchange.
|
||||
# conversation_history includes the exchange that just happened,
|
||||
# so for a first exchange we expect exactly 1 user message
|
||||
# (or 2 counting system). Be generous: generate on first 2 exchanges.
|
||||
user_msg_count = sum(1 for m in (conversation_history or []) if m.get("role") == "user")
|
||||
if user_msg_count > 2:
|
||||
return
|
||||
|
||||
thread = threading.Thread(
|
||||
target=auto_title_session,
|
||||
args=(session_db, session_id, user_message, assistant_response),
|
||||
daemon=True,
|
||||
name="auto-title",
|
||||
)
|
||||
thread.start()
|
||||
@@ -0,0 +1,655 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from agent.model_metadata import fetch_endpoint_model_metadata, fetch_model_metadata
|
||||
|
||||
DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
|
||||
|
||||
_ZERO = Decimal("0")
|
||||
_ONE_MILLION = Decimal("1000000")
|
||||
|
||||
CostStatus = Literal["actual", "estimated", "included", "unknown"]
|
||||
CostSource = Literal[
|
||||
"provider_cost_api",
|
||||
"provider_generation_api",
|
||||
"provider_models_api",
|
||||
"official_docs_snapshot",
|
||||
"user_override",
|
||||
"custom_contract",
|
||||
"none",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CanonicalUsage:
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
cache_write_tokens: int = 0
|
||||
reasoning_tokens: int = 0
|
||||
request_count: int = 1
|
||||
raw_usage: Optional[dict[str, Any]] = None
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return self.input_tokens + self.cache_read_tokens + self.cache_write_tokens
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
return self.prompt_tokens + self.output_tokens
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BillingRoute:
|
||||
provider: str
|
||||
model: str
|
||||
base_url: str = ""
|
||||
billing_mode: str = "unknown"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PricingEntry:
|
||||
input_cost_per_million: Optional[Decimal] = None
|
||||
output_cost_per_million: Optional[Decimal] = None
|
||||
cache_read_cost_per_million: Optional[Decimal] = None
|
||||
cache_write_cost_per_million: Optional[Decimal] = None
|
||||
request_cost: Optional[Decimal] = None
|
||||
source: CostSource = "none"
|
||||
source_url: Optional[str] = None
|
||||
pricing_version: Optional[str] = None
|
||||
fetched_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CostResult:
|
||||
amount_usd: Optional[Decimal]
|
||||
status: CostStatus
|
||||
source: CostSource
|
||||
label: str
|
||||
fetched_at: Optional[datetime] = None
|
||||
pricing_version: Optional[str] = None
|
||||
notes: tuple[str, ...] = ()
|
||||
|
||||
|
||||
_UTC_NOW = lambda: datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# Official docs snapshot entries. Models whose published pricing and cache
|
||||
# semantics are stable enough to encode exactly.
|
||||
_OFFICIAL_DOCS_PRICING: Dict[tuple[str, str], PricingEntry] = {
|
||||
(
|
||||
"anthropic",
|
||||
"claude-opus-4-20250514",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("15.00"),
|
||||
output_cost_per_million=Decimal("75.00"),
|
||||
cache_read_cost_per_million=Decimal("1.50"),
|
||||
cache_write_cost_per_million=Decimal("18.75"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-prompt-caching-2026-03-16",
|
||||
),
|
||||
(
|
||||
"anthropic",
|
||||
"claude-sonnet-4-20250514",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("3.00"),
|
||||
output_cost_per_million=Decimal("15.00"),
|
||||
cache_read_cost_per_million=Decimal("0.30"),
|
||||
cache_write_cost_per_million=Decimal("3.75"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-prompt-caching-2026-03-16",
|
||||
),
|
||||
# OpenAI
|
||||
(
|
||||
"openai",
|
||||
"gpt-4o",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("2.50"),
|
||||
output_cost_per_million=Decimal("10.00"),
|
||||
cache_read_cost_per_million=Decimal("1.25"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.15"),
|
||||
output_cost_per_million=Decimal("0.60"),
|
||||
cache_read_cost_per_million=Decimal("0.075"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"gpt-4.1",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("2.00"),
|
||||
output_cost_per_million=Decimal("8.00"),
|
||||
cache_read_cost_per_million=Decimal("0.50"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"gpt-4.1-mini",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.40"),
|
||||
output_cost_per_million=Decimal("1.60"),
|
||||
cache_read_cost_per_million=Decimal("0.10"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"gpt-4.1-nano",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.10"),
|
||||
output_cost_per_million=Decimal("0.40"),
|
||||
cache_read_cost_per_million=Decimal("0.025"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"o3",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("10.00"),
|
||||
output_cost_per_million=Decimal("40.00"),
|
||||
cache_read_cost_per_million=Decimal("2.50"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"o3-mini",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("1.10"),
|
||||
output_cost_per_million=Decimal("4.40"),
|
||||
cache_read_cost_per_million=Decimal("0.55"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
# Anthropic older models (pre-4.6 generation)
|
||||
(
|
||||
"anthropic",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("3.00"),
|
||||
output_cost_per_million=Decimal("15.00"),
|
||||
cache_read_cost_per_million=Decimal("0.30"),
|
||||
cache_write_cost_per_million=Decimal("3.75"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"anthropic",
|
||||
"claude-3-5-haiku-20241022",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.80"),
|
||||
output_cost_per_million=Decimal("4.00"),
|
||||
cache_read_cost_per_million=Decimal("0.08"),
|
||||
cache_write_cost_per_million=Decimal("1.00"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"anthropic",
|
||||
"claude-3-opus-20240229",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("15.00"),
|
||||
output_cost_per_million=Decimal("75.00"),
|
||||
cache_read_cost_per_million=Decimal("1.50"),
|
||||
cache_write_cost_per_million=Decimal("18.75"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"anthropic",
|
||||
"claude-3-haiku-20240307",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.25"),
|
||||
output_cost_per_million=Decimal("1.25"),
|
||||
cache_read_cost_per_million=Decimal("0.03"),
|
||||
cache_write_cost_per_million=Decimal("0.30"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-pricing-2026-03-16",
|
||||
),
|
||||
# DeepSeek
|
||||
(
|
||||
"deepseek",
|
||||
"deepseek-chat",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.14"),
|
||||
output_cost_per_million=Decimal("0.28"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://api-docs.deepseek.com/quick_start/pricing",
|
||||
pricing_version="deepseek-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"deepseek",
|
||||
"deepseek-reasoner",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.55"),
|
||||
output_cost_per_million=Decimal("2.19"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://api-docs.deepseek.com/quick_start/pricing",
|
||||
pricing_version="deepseek-pricing-2026-03-16",
|
||||
),
|
||||
# Google Gemini
|
||||
(
|
||||
"google",
|
||||
"gemini-2.5-pro",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("1.25"),
|
||||
output_cost_per_million=Decimal("10.00"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://ai.google.dev/pricing",
|
||||
pricing_version="google-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"google",
|
||||
"gemini-2.5-flash",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.15"),
|
||||
output_cost_per_million=Decimal("0.60"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://ai.google.dev/pricing",
|
||||
pricing_version="google-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"google",
|
||||
"gemini-2.0-flash",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.10"),
|
||||
output_cost_per_million=Decimal("0.40"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://ai.google.dev/pricing",
|
||||
pricing_version="google-pricing-2026-03-16",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _to_decimal(value: Any) -> Optional[Decimal]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return Decimal(str(value))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _to_int(value: Any) -> int:
|
||||
try:
|
||||
return int(value or 0)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def resolve_billing_route(
|
||||
model_name: str,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> BillingRoute:
|
||||
provider_name = (provider or "").strip().lower()
|
||||
base = (base_url or "").strip().lower()
|
||||
model = (model_name or "").strip()
|
||||
if not provider_name and "/" in model:
|
||||
inferred_provider, bare_model = model.split("/", 1)
|
||||
if inferred_provider in {"anthropic", "openai", "google"}:
|
||||
provider_name = inferred_provider
|
||||
model = bare_model
|
||||
|
||||
if provider_name == "openai-codex":
|
||||
return BillingRoute(provider="openai-codex", model=model, base_url=base_url or "", billing_mode="subscription_included")
|
||||
if provider_name == "openrouter" or "openrouter.ai" in base:
|
||||
return BillingRoute(provider="openrouter", model=model, base_url=base_url or "", billing_mode="official_models_api")
|
||||
if provider_name == "anthropic":
|
||||
return BillingRoute(provider="anthropic", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot")
|
||||
if provider_name == "openai":
|
||||
return BillingRoute(provider="openai", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot")
|
||||
if provider_name in {"custom", "local"} or (base and "localhost" in base):
|
||||
return BillingRoute(provider=provider_name or "custom", model=model, base_url=base_url or "", billing_mode="unknown")
|
||||
return BillingRoute(provider=provider_name or "unknown", model=model.split("/")[-1] if model else "", base_url=base_url or "", billing_mode="unknown")
|
||||
|
||||
|
||||
def _lookup_official_docs_pricing(route: BillingRoute) -> Optional[PricingEntry]:
|
||||
return _OFFICIAL_DOCS_PRICING.get((route.provider, route.model.lower()))
|
||||
|
||||
|
||||
def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
|
||||
return _pricing_entry_from_metadata(
|
||||
fetch_model_metadata(),
|
||||
route.model,
|
||||
source_url="https://openrouter.ai/docs/api/api-reference/models/get-models",
|
||||
pricing_version="openrouter-models-api",
|
||||
)
|
||||
|
||||
|
||||
def _pricing_entry_from_metadata(
|
||||
metadata: Dict[str, Dict[str, Any]],
|
||||
model_id: str,
|
||||
*,
|
||||
source_url: str,
|
||||
pricing_version: str,
|
||||
) -> Optional[PricingEntry]:
|
||||
if model_id not in metadata:
|
||||
return None
|
||||
pricing = metadata[model_id].get("pricing") or {}
|
||||
prompt = _to_decimal(pricing.get("prompt"))
|
||||
completion = _to_decimal(pricing.get("completion"))
|
||||
request = _to_decimal(pricing.get("request"))
|
||||
cache_read = _to_decimal(
|
||||
pricing.get("cache_read")
|
||||
or pricing.get("cached_prompt")
|
||||
or pricing.get("input_cache_read")
|
||||
)
|
||||
cache_write = _to_decimal(
|
||||
pricing.get("cache_write")
|
||||
or pricing.get("cache_creation")
|
||||
or pricing.get("input_cache_write")
|
||||
)
|
||||
if prompt is None and completion is None and request is None:
|
||||
return None
|
||||
|
||||
def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]:
|
||||
if value is None:
|
||||
return None
|
||||
return value * _ONE_MILLION
|
||||
|
||||
return PricingEntry(
|
||||
input_cost_per_million=_per_token_to_per_million(prompt),
|
||||
output_cost_per_million=_per_token_to_per_million(completion),
|
||||
cache_read_cost_per_million=_per_token_to_per_million(cache_read),
|
||||
cache_write_cost_per_million=_per_token_to_per_million(cache_write),
|
||||
request_cost=request,
|
||||
source="provider_models_api",
|
||||
source_url=source_url,
|
||||
pricing_version=pricing_version,
|
||||
fetched_at=_UTC_NOW(),
|
||||
)
|
||||
|
||||
|
||||
def get_pricing_entry(
|
||||
model_name: str,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> Optional[PricingEntry]:
|
||||
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
||||
if route.billing_mode == "subscription_included":
|
||||
return PricingEntry(
|
||||
input_cost_per_million=_ZERO,
|
||||
output_cost_per_million=_ZERO,
|
||||
cache_read_cost_per_million=_ZERO,
|
||||
cache_write_cost_per_million=_ZERO,
|
||||
source="none",
|
||||
pricing_version="included-route",
|
||||
)
|
||||
if route.provider == "openrouter":
|
||||
return _openrouter_pricing_entry(route)
|
||||
if route.base_url:
|
||||
entry = _pricing_entry_from_metadata(
|
||||
fetch_endpoint_model_metadata(route.base_url, api_key=api_key or ""),
|
||||
route.model,
|
||||
source_url=f"{route.base_url.rstrip('/')}/models",
|
||||
pricing_version="openai-compatible-models-api",
|
||||
)
|
||||
if entry:
|
||||
return entry
|
||||
return _lookup_official_docs_pricing(route)
|
||||
|
||||
|
||||
def normalize_usage(
|
||||
response_usage: Any,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
api_mode: Optional[str] = None,
|
||||
) -> CanonicalUsage:
|
||||
"""Normalize raw API response usage into canonical token buckets.
|
||||
|
||||
Handles three API shapes:
|
||||
- Anthropic: input_tokens/output_tokens/cache_read_input_tokens/cache_creation_input_tokens
|
||||
- Codex Responses: input_tokens includes cache tokens; input_tokens_details.cached_tokens separates them
|
||||
- OpenAI Chat Completions: prompt_tokens includes cache tokens; prompt_tokens_details.cached_tokens separates them
|
||||
|
||||
In both Codex and OpenAI modes, input_tokens is derived by subtracting cache
|
||||
tokens from the total — the API contract is that input/prompt totals include
|
||||
cached tokens and the details object breaks them out.
|
||||
"""
|
||||
if not response_usage:
|
||||
return CanonicalUsage()
|
||||
|
||||
provider_name = (provider or "").strip().lower()
|
||||
mode = (api_mode or "").strip().lower()
|
||||
|
||||
if mode == "anthropic_messages" or provider_name == "anthropic":
|
||||
input_tokens = _to_int(getattr(response_usage, "input_tokens", 0))
|
||||
output_tokens = _to_int(getattr(response_usage, "output_tokens", 0))
|
||||
cache_read_tokens = _to_int(getattr(response_usage, "cache_read_input_tokens", 0))
|
||||
cache_write_tokens = _to_int(getattr(response_usage, "cache_creation_input_tokens", 0))
|
||||
elif mode == "codex_responses":
|
||||
input_total = _to_int(getattr(response_usage, "input_tokens", 0))
|
||||
output_tokens = _to_int(getattr(response_usage, "output_tokens", 0))
|
||||
details = getattr(response_usage, "input_tokens_details", None)
|
||||
cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0)
|
||||
cache_write_tokens = _to_int(
|
||||
getattr(details, "cache_creation_tokens", 0) if details else 0
|
||||
)
|
||||
input_tokens = max(0, input_total - cache_read_tokens - cache_write_tokens)
|
||||
else:
|
||||
prompt_total = _to_int(getattr(response_usage, "prompt_tokens", 0))
|
||||
output_tokens = _to_int(getattr(response_usage, "completion_tokens", 0))
|
||||
details = getattr(response_usage, "prompt_tokens_details", None)
|
||||
cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0)
|
||||
cache_write_tokens = _to_int(
|
||||
getattr(details, "cache_write_tokens", 0) if details else 0
|
||||
)
|
||||
input_tokens = max(0, prompt_total - cache_read_tokens - cache_write_tokens)
|
||||
|
||||
reasoning_tokens = 0
|
||||
output_details = getattr(response_usage, "output_tokens_details", None)
|
||||
if output_details:
|
||||
reasoning_tokens = _to_int(getattr(output_details, "reasoning_tokens", 0))
|
||||
|
||||
return CanonicalUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
)
|
||||
|
||||
|
||||
def estimate_usage_cost(
|
||||
model_name: str,
|
||||
usage: CanonicalUsage,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> CostResult:
|
||||
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
||||
if route.billing_mode == "subscription_included":
|
||||
return CostResult(
|
||||
amount_usd=_ZERO,
|
||||
status="included",
|
||||
source="none",
|
||||
label="included",
|
||||
pricing_version="included-route",
|
||||
)
|
||||
|
||||
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
|
||||
if not entry:
|
||||
return CostResult(amount_usd=None, status="unknown", source="none", label="n/a")
|
||||
|
||||
notes: list[str] = []
|
||||
amount = _ZERO
|
||||
|
||||
if usage.input_tokens and entry.input_cost_per_million is None:
|
||||
return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a")
|
||||
if usage.output_tokens and entry.output_cost_per_million is None:
|
||||
return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a")
|
||||
if usage.cache_read_tokens:
|
||||
if entry.cache_read_cost_per_million is None:
|
||||
return CostResult(
|
||||
amount_usd=None,
|
||||
status="unknown",
|
||||
source=entry.source,
|
||||
label="n/a",
|
||||
notes=("cache-read pricing unavailable for route",),
|
||||
)
|
||||
if usage.cache_write_tokens:
|
||||
if entry.cache_write_cost_per_million is None:
|
||||
return CostResult(
|
||||
amount_usd=None,
|
||||
status="unknown",
|
||||
source=entry.source,
|
||||
label="n/a",
|
||||
notes=("cache-write pricing unavailable for route",),
|
||||
)
|
||||
|
||||
if entry.input_cost_per_million is not None:
|
||||
amount += Decimal(usage.input_tokens) * entry.input_cost_per_million / _ONE_MILLION
|
||||
if entry.output_cost_per_million is not None:
|
||||
amount += Decimal(usage.output_tokens) * entry.output_cost_per_million / _ONE_MILLION
|
||||
if entry.cache_read_cost_per_million is not None:
|
||||
amount += Decimal(usage.cache_read_tokens) * entry.cache_read_cost_per_million / _ONE_MILLION
|
||||
if entry.cache_write_cost_per_million is not None:
|
||||
amount += Decimal(usage.cache_write_tokens) * entry.cache_write_cost_per_million / _ONE_MILLION
|
||||
if entry.request_cost is not None and usage.request_count:
|
||||
amount += Decimal(usage.request_count) * entry.request_cost
|
||||
|
||||
status: CostStatus = "estimated"
|
||||
label = f"~${amount:.2f}"
|
||||
if entry.source == "none" and amount == _ZERO:
|
||||
status = "included"
|
||||
label = "included"
|
||||
|
||||
if route.provider == "openrouter":
|
||||
notes.append("OpenRouter cost is estimated from the models API until reconciled.")
|
||||
|
||||
return CostResult(
|
||||
amount_usd=amount,
|
||||
status=status,
|
||||
source=entry.source,
|
||||
label=label,
|
||||
fetched_at=entry.fetched_at,
|
||||
pricing_version=entry.pricing_version,
|
||||
notes=tuple(notes),
|
||||
)
|
||||
|
||||
|
||||
def has_known_pricing(
|
||||
model_name: str,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Check whether we have pricing data for this model+route.
|
||||
|
||||
Uses direct lookup instead of routing through the full estimation
|
||||
pipeline — avoids creating dummy usage objects just to check status.
|
||||
"""
|
||||
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
||||
if route.billing_mode == "subscription_included":
|
||||
return True
|
||||
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
|
||||
return entry is not None
|
||||
|
||||
|
||||
def get_pricing(
|
||||
model_name: str,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""Backward-compatible thin wrapper for legacy callers.
|
||||
|
||||
Returns only non-cache input/output fields when a pricing entry exists.
|
||||
Unknown routes return zeroes.
|
||||
"""
|
||||
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
|
||||
if not entry:
|
||||
return {"input": 0.0, "output": 0.0}
|
||||
return {
|
||||
"input": float(entry.input_cost_per_million or _ZERO),
|
||||
"output": float(entry.output_cost_per_million or _ZERO),
|
||||
}
|
||||
|
||||
|
||||
def estimate_cost_usd(
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> float:
|
||||
"""Backward-compatible helper for legacy callers.
|
||||
|
||||
This uses non-cached input/output only. New code should call
|
||||
`estimate_usage_cost()` with canonical usage buckets.
|
||||
"""
|
||||
result = estimate_usage_cost(
|
||||
model,
|
||||
CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
|
||||
provider=provider,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
)
|
||||
return float(result.amount_usd or _ZERO)
|
||||
|
||||
|
||||
def format_duration_compact(seconds: float) -> str:
|
||||
if seconds < 60:
|
||||
return f"{seconds:.0f}s"
|
||||
minutes = seconds / 60
|
||||
if minutes < 60:
|
||||
return f"{minutes:.0f}m"
|
||||
hours = minutes / 60
|
||||
if hours < 24:
|
||||
remaining_min = int(minutes % 60)
|
||||
return f"{int(hours)}h {remaining_min}m" if remaining_min else f"{int(hours)}h"
|
||||
days = hours / 24
|
||||
return f"{days:.1f}d"
|
||||
|
||||
|
||||
def format_token_count_compact(value: int) -> str:
|
||||
abs_value = abs(int(value))
|
||||
if abs_value < 1_000:
|
||||
return str(int(value))
|
||||
|
||||
sign = "-" if value < 0 else ""
|
||||
units = ((1_000_000_000, "B"), (1_000_000, "M"), (1_000, "K"))
|
||||
for threshold, suffix in units:
|
||||
if abs_value >= threshold:
|
||||
scaled = abs_value / threshold
|
||||
if scaled < 10:
|
||||
text = f"{scaled:.2f}"
|
||||
elif scaled < 100:
|
||||
text = f"{scaled:.1f}"
|
||||
else:
|
||||
text = f"{scaled:.0f}"
|
||||
text = text.rstrip("0").rstrip(".")
|
||||
return f"{sign}{text}{suffix}"
|
||||
|
||||
return f"{value:,}"
|
||||
@@ -128,6 +128,7 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i
|
||||
# Track tool calls from assistant messages
|
||||
if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
|
||||
for tool_call in msg["tool_calls"]:
|
||||
if not tool_call or not isinstance(tool_call, dict): continue
|
||||
tool_name = tool_call["function"]["name"]
|
||||
tool_call_id = tool_call["id"]
|
||||
|
||||
|
||||
+65
-49
@@ -51,6 +51,20 @@ model:
|
||||
# # Data policy: "allow" (default) or "deny" to exclude providers that may store data
|
||||
# # data_collection: "deny"
|
||||
|
||||
# =============================================================================
|
||||
# Smart Model Routing (optional)
|
||||
# =============================================================================
|
||||
# Use a cheaper model for short/simple turns while keeping your main model for
|
||||
# more complex requests. Disabled by default.
|
||||
#
|
||||
# smart_model_routing:
|
||||
# enabled: true
|
||||
# max_simple_chars: 160
|
||||
# max_simple_words: 28
|
||||
# cheap_model:
|
||||
# provider: openrouter
|
||||
# model: google/gemini-2.5-flash
|
||||
|
||||
# =============================================================================
|
||||
# Git Worktree Isolation
|
||||
# =============================================================================
|
||||
@@ -76,8 +90,9 @@ model:
|
||||
# - Messaging (Telegram/Discord): Uses MESSAGING_CWD from .env (default: home)
|
||||
terminal:
|
||||
backend: "local"
|
||||
cwd: "." # For local backend: "." = current directory. Ignored for remote backends.
|
||||
cwd: "." # For local backend: "." = current directory. Ignored for remote backends unless a backend documents otherwise.
|
||||
timeout: 180
|
||||
docker_mount_cwd_to_workspace: false # SECURITY: off by default. Opt in to mount the launch cwd into Docker /workspace.
|
||||
lifetime_seconds: 300
|
||||
# sudo_password: "" # Enable sudo commands (pipes via sudo -S) - SECURITY WARNING: plaintext!
|
||||
|
||||
@@ -107,6 +122,13 @@ terminal:
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# docker_image: "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
# docker_mount_cwd_to_workspace: true # Explicit opt-in: mount your launch cwd into /workspace
|
||||
# # Optional: explicitly forward selected env vars into Docker.
|
||||
# # These values come from your current shell first, then ~/.hermes/.env.
|
||||
# # Warning: anything forwarded here is visible to commands run in the container.
|
||||
# docker_forward_env:
|
||||
# - "GITHUB_TOKEN"
|
||||
# - "NPM_TOKEN"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 4: Singularity/Apptainer container
|
||||
@@ -333,6 +355,25 @@ session_reset:
|
||||
idle_minutes: 1440 # Inactivity timeout in minutes (default: 1440 = 24 hours)
|
||||
at_hour: 4 # Daily reset hour, 0-23 local time (default: 4 AM)
|
||||
|
||||
# When true, group/channel chats use one session per participant when the platform
|
||||
# provides a user ID. This is the secure default and prevents users in the same
|
||||
# room from sharing context, interrupts, and token costs. Set false only if you
|
||||
# explicitly want one shared "room brain" per group/channel.
|
||||
group_sessions_per_user: true
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Gateway Streaming
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Stream tokens to messaging platforms in real-time. The bot sends a message
|
||||
# on first token, then progressively edits it as more tokens arrive.
|
||||
# Disabled by default — enable to try the streaming UX on Telegram/Discord/Slack.
|
||||
streaming:
|
||||
enabled: false
|
||||
# transport: edit # "edit" = progressive editMessageText
|
||||
# edit_interval: 0.3 # seconds between message edits
|
||||
# buffer_threshold: 40 # chars before forcing an edit flush
|
||||
# cursor: " ▉" # cursor shown during streaming
|
||||
|
||||
# =============================================================================
|
||||
# Skills Configuration
|
||||
# =============================================================================
|
||||
@@ -383,7 +424,7 @@ agent:
|
||||
# Toolsets
|
||||
# =============================================================================
|
||||
# Control which tools the agent has access to.
|
||||
# Use "all" to enable everything, or specify individual toolsets.
|
||||
# Use `hermes tools` to interactively enable/disable tools per platform.
|
||||
|
||||
# =============================================================================
|
||||
# Platform Toolsets (per-platform tool configuration)
|
||||
@@ -492,53 +533,11 @@ platform_toolsets:
|
||||
# debugging - terminal + web + file (for troubleshooting)
|
||||
# safe - web + vision + moa (no terminal access)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 1: Enable all tools (default)
|
||||
# -----------------------------------------------------------------------------
|
||||
toolsets:
|
||||
- all
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 2: Minimal - just web search and terminal
|
||||
# Great for: Simple coding tasks, quick lookups
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - web
|
||||
# - terminal
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 3: Research mode - no execution capabilities
|
||||
# Great for: Safe information gathering, research tasks
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - web
|
||||
# - vision
|
||||
# - skills
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 4: Full automation - browser + terminal
|
||||
# Great for: Web scraping, automation tasks, testing
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - terminal
|
||||
# - browser
|
||||
# - web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 5: Creative mode - vision + image generation
|
||||
# Great for: Design work, image analysis, creative tasks
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - vision
|
||||
# - image_gen
|
||||
# - web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 6: Safe mode - no terminal or browser
|
||||
# Great for: Restricted environments, untrusted queries
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - safe
|
||||
# NOTE: The top-level "toolsets" key is deprecated and ignored.
|
||||
# Tool configuration is managed per-platform via platform_toolsets above.
|
||||
# Use `hermes tools` to configure interactively, or edit platform_toolsets directly.
|
||||
#
|
||||
# CLI override: hermes chat --toolsets terminal,web,file
|
||||
|
||||
# =============================================================================
|
||||
# MCP (Model Context Protocol) Servers
|
||||
@@ -694,6 +693,12 @@ display:
|
||||
# Toggle at runtime with /reasoning show or /reasoning hide.
|
||||
show_reasoning: false
|
||||
|
||||
# Stream tokens to the terminal as they arrive instead of waiting for the
|
||||
# full response. The response box opens on first token and text appears
|
||||
# line-by-line. Tool calls are still captured silently.
|
||||
# Stream tokens to the terminal in real-time. Disable to wait for full responses.
|
||||
streaming: true
|
||||
|
||||
# ───────────────────────────────────────────────────────────────────────────
|
||||
# Skin / Theme
|
||||
# ───────────────────────────────────────────────────────────────────────────
|
||||
@@ -734,3 +739,14 @@ display:
|
||||
# tool_prefix: "╎" # Tool output line prefix (default: ┊)
|
||||
#
|
||||
skin: default
|
||||
|
||||
# =============================================================================
|
||||
# Privacy
|
||||
# =============================================================================
|
||||
# privacy:
|
||||
# # Redact PII from the LLM context prompt.
|
||||
# # When true, phone numbers are stripped and user/chat IDs are replaced
|
||||
# # with deterministic hashes before being sent to the model.
|
||||
# # Names and usernames are NOT affected (user-chosen, publicly visible).
|
||||
# # Routing/delivery still uses the original values internally.
|
||||
# redact_pii: false
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+135
-7
@@ -5,7 +5,9 @@ Jobs are stored in ~/.hermes/cron/jobs.json
|
||||
Output is saved to ~/.hermes/cron/output/{job_id}/{timestamp}.md
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
import os
|
||||
import re
|
||||
@@ -14,6 +16,8 @@ from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
|
||||
try:
|
||||
@@ -30,6 +34,7 @@ HERMES_DIR = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
CRON_DIR = HERMES_DIR / "cron"
|
||||
JOBS_FILE = CRON_DIR / "jobs.json"
|
||||
OUTPUT_DIR = CRON_DIR / "output"
|
||||
ONESHOT_GRACE_SECONDS = 120
|
||||
|
||||
|
||||
def _normalize_skill_list(skill: Optional[str] = None, skills: Optional[Any] = None) -> List[str]:
|
||||
@@ -164,6 +169,10 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
|
||||
try:
|
||||
# Parse and validate
|
||||
dt = datetime.fromisoformat(schedule.replace('Z', '+00:00'))
|
||||
# Make naive timestamps timezone-aware at parse time so the stored
|
||||
# value doesn't depend on the system timezone matching at check time.
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.astimezone() # Interpret as local timezone
|
||||
return {
|
||||
"kind": "once",
|
||||
"run_at": dt.isoformat(),
|
||||
@@ -212,6 +221,65 @@ def _ensure_aware(dt: datetime) -> datetime:
|
||||
return dt.astimezone(target_tz)
|
||||
|
||||
|
||||
def _recoverable_oneshot_run_at(
|
||||
schedule: Dict[str, Any],
|
||||
now: datetime,
|
||||
*,
|
||||
last_run_at: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""Return a one-shot run time if it is still eligible to fire.
|
||||
|
||||
One-shot jobs get a small grace window so jobs created a few seconds after
|
||||
their requested minute still run on the next tick. Once a one-shot has
|
||||
already run, it is never eligible again.
|
||||
"""
|
||||
if schedule.get("kind") != "once":
|
||||
return None
|
||||
if last_run_at:
|
||||
return None
|
||||
|
||||
run_at = schedule.get("run_at")
|
||||
if not run_at:
|
||||
return None
|
||||
|
||||
run_at_dt = _ensure_aware(datetime.fromisoformat(run_at))
|
||||
if run_at_dt >= now - timedelta(seconds=ONESHOT_GRACE_SECONDS):
|
||||
return run_at
|
||||
return None
|
||||
|
||||
|
||||
def _compute_grace_seconds(schedule: dict) -> int:
|
||||
"""Compute how late a job can be and still catch up instead of fast-forwarding.
|
||||
|
||||
Uses half the schedule period, clamped between 120 seconds and 2 hours.
|
||||
This ensures daily jobs can catch up if missed by up to 2 hours,
|
||||
while frequent jobs (every 5-10 min) still fast-forward quickly.
|
||||
"""
|
||||
MIN_GRACE = 120
|
||||
MAX_GRACE = 7200 # 2 hours
|
||||
|
||||
kind = schedule.get("kind")
|
||||
|
||||
if kind == "interval":
|
||||
period_seconds = schedule.get("minutes", 1) * 60
|
||||
grace = period_seconds // 2
|
||||
return max(MIN_GRACE, min(grace, MAX_GRACE))
|
||||
|
||||
if kind == "cron" and HAS_CRONITER:
|
||||
try:
|
||||
now = _hermes_now()
|
||||
cron = croniter(schedule["expr"], now)
|
||||
first = cron.get_next(datetime)
|
||||
second = cron.get_next(datetime)
|
||||
period_seconds = int((second - first).total_seconds())
|
||||
grace = period_seconds // 2
|
||||
return max(MIN_GRACE, min(grace, MAX_GRACE))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return MIN_GRACE
|
||||
|
||||
|
||||
def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Compute the next run time for a schedule.
|
||||
@@ -221,9 +289,7 @@ def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None
|
||||
now = _hermes_now()
|
||||
|
||||
if schedule["kind"] == "once":
|
||||
run_at = _ensure_aware(datetime.fromisoformat(schedule["run_at"]))
|
||||
# If in the future, return it; if in the past, no more runs
|
||||
return schedule["run_at"] if run_at > now else None
|
||||
return _recoverable_oneshot_run_at(schedule, now, last_run_at=last_run_at)
|
||||
|
||||
elif schedule["kind"] == "interval":
|
||||
minutes = schedule["minutes"]
|
||||
@@ -317,6 +383,10 @@ def create_job(
|
||||
"""
|
||||
parsed_schedule = parse_schedule(schedule)
|
||||
|
||||
# Normalize repeat: treat 0 or negative values as None (infinite)
|
||||
if repeat is not None and repeat <= 0:
|
||||
repeat = None
|
||||
|
||||
# Auto-set repeat=1 for one-shot schedules if not specified
|
||||
if parsed_schedule["kind"] == "once" and repeat is None:
|
||||
repeat = 1
|
||||
@@ -505,7 +575,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
# Check if we've hit the repeat limit
|
||||
times = job["repeat"].get("times")
|
||||
completed = job["repeat"]["completed"]
|
||||
if times is not None and completed >= times:
|
||||
if times is not None and times > 0 and completed >= times:
|
||||
# Remove the job (limit reached)
|
||||
jobs.pop(i)
|
||||
save_jobs(jobs)
|
||||
@@ -528,10 +598,18 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
|
||||
|
||||
def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
"""Get all jobs that are due to run now."""
|
||||
"""Get all jobs that are due to run now.
|
||||
|
||||
For recurring jobs (cron/interval), if the scheduled time is stale
|
||||
(more than one period in the past, e.g. because the gateway was down),
|
||||
the job is fast-forwarded to the next future run instead of firing
|
||||
immediately. This prevents a burst of missed jobs on gateway restart.
|
||||
"""
|
||||
now = _hermes_now()
|
||||
jobs = [_apply_skill_fields(j) for j in load_jobs()]
|
||||
raw_jobs = load_jobs()
|
||||
jobs = [_apply_skill_fields(j) for j in copy.deepcopy(raw_jobs)]
|
||||
due = []
|
||||
needs_save = False
|
||||
|
||||
for job in jobs:
|
||||
if not job.get("enabled", True):
|
||||
@@ -539,12 +617,62 @@ def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
|
||||
next_run = job.get("next_run_at")
|
||||
if not next_run:
|
||||
continue
|
||||
recovered_next = _recoverable_oneshot_run_at(
|
||||
job.get("schedule", {}),
|
||||
now,
|
||||
last_run_at=job.get("last_run_at"),
|
||||
)
|
||||
if not recovered_next:
|
||||
continue
|
||||
|
||||
job["next_run_at"] = recovered_next
|
||||
next_run = recovered_next
|
||||
logger.info(
|
||||
"Job '%s' had no next_run_at; recovering one-shot run at %s",
|
||||
job.get("name", job["id"]),
|
||||
recovered_next,
|
||||
)
|
||||
for rj in raw_jobs:
|
||||
if rj["id"] == job["id"]:
|
||||
rj["next_run_at"] = recovered_next
|
||||
needs_save = True
|
||||
break
|
||||
|
||||
next_run_dt = _ensure_aware(datetime.fromisoformat(next_run))
|
||||
if next_run_dt <= now:
|
||||
schedule = job.get("schedule", {})
|
||||
kind = schedule.get("kind")
|
||||
|
||||
# For recurring jobs, check if the scheduled time is stale
|
||||
# (gateway was down and missed the window). Fast-forward to
|
||||
# the next future occurrence instead of firing a stale run.
|
||||
grace = _compute_grace_seconds(schedule)
|
||||
if kind in ("cron", "interval") and (now - next_run_dt).total_seconds() > grace:
|
||||
# Job is past its catch-up grace window — this is a stale missed run.
|
||||
# Grace scales with schedule period: daily=2h, hourly=30m, 10min=5m.
|
||||
new_next = compute_next_run(schedule, now.isoformat())
|
||||
if new_next:
|
||||
logger.info(
|
||||
"Job '%s' missed its scheduled time (%s, grace=%ds). "
|
||||
"Fast-forwarding to next run: %s",
|
||||
job.get("name", job["id"]),
|
||||
next_run,
|
||||
grace,
|
||||
new_next,
|
||||
)
|
||||
# Update the job in storage
|
||||
for rj in raw_jobs:
|
||||
if rj["id"] == job["id"]:
|
||||
rj["next_run_at"] = new_next
|
||||
needs_save = True
|
||||
break
|
||||
continue # Skip this run
|
||||
|
||||
due.append(job)
|
||||
|
||||
if needs_save:
|
||||
save_jobs(raw_jobs)
|
||||
|
||||
return due
|
||||
|
||||
|
||||
|
||||
+97
-25
@@ -37,6 +37,11 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output
|
||||
|
||||
# Sentinel: when a cron agent has nothing new to report, it can start its
|
||||
# response with this marker to suppress delivery. Output is still saved
|
||||
# locally for audit.
|
||||
SILENT_MARKER = "[SILENT]"
|
||||
|
||||
# Resolve Hermes home directory (respects HERMES_HOME override)
|
||||
_hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
@@ -75,11 +80,16 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
}
|
||||
|
||||
if ":" in deliver:
|
||||
platform_name, chat_id = deliver.split(":", 1)
|
||||
platform_name, rest = deliver.split(":", 1)
|
||||
# Check for thread_id suffix (e.g. "telegram:-1003724596514:17")
|
||||
if ":" in rest:
|
||||
chat_id, thread_id = rest.split(":", 1)
|
||||
else:
|
||||
chat_id, thread_id = rest, None
|
||||
return {
|
||||
"platform": platform_name,
|
||||
"chat_id": chat_id,
|
||||
"thread_id": None,
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
|
||||
platform_name = deliver
|
||||
@@ -131,7 +141,12 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
"slack": Platform.SLACK,
|
||||
"whatsapp": Platform.WHATSAPP,
|
||||
"signal": Platform.SIGNAL,
|
||||
"matrix": Platform.MATRIX,
|
||||
"mattermost": Platform.MATTERMOST,
|
||||
"homeassistant": Platform.HOMEASSISTANT,
|
||||
"dingtalk": Platform.DINGTALK,
|
||||
"email": Platform.EMAIL,
|
||||
"sms": Platform.SMS,
|
||||
}
|
||||
platform = platform_map.get(platform_name.lower())
|
||||
if not platform:
|
||||
@@ -149,15 +164,29 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
logger.warning("Job '%s': platform '%s' not configured/enabled", job["id"], platform_name)
|
||||
return
|
||||
|
||||
# Wrap the content so the user knows this is a cron delivery and that
|
||||
# the interactive agent has no visibility into it.
|
||||
task_name = job.get("name", job["id"])
|
||||
wrapped = (
|
||||
f"Cronjob Response: {task_name}\n"
|
||||
f"-------------\n\n"
|
||||
f"{content}\n\n"
|
||||
f"Note: The agent cannot see this message, and therefore cannot respond to it."
|
||||
)
|
||||
|
||||
# Run the async send in a fresh event loop (safe from any thread)
|
||||
coro = _send_to_platform(platform, pconfig, chat_id, wrapped, thread_id=thread_id)
|
||||
try:
|
||||
result = asyncio.run(_send_to_platform(platform, pconfig, chat_id, content, thread_id=thread_id))
|
||||
result = asyncio.run(coro)
|
||||
except RuntimeError:
|
||||
# asyncio.run() fails if there's already a running loop in this thread;
|
||||
# spin up a new thread to avoid that.
|
||||
# asyncio.run() checks for a running loop before awaiting the coroutine;
|
||||
# when it raises, the original coro was never started — close it to
|
||||
# prevent "coroutine was never awaited" RuntimeWarning, then retry in a
|
||||
# fresh thread that has no running loop.
|
||||
coro.close()
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, content, thread_id=thread_id))
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, wrapped, thread_id=thread_id))
|
||||
result = future.result(timeout=30)
|
||||
except Exception as e:
|
||||
logger.error("Job '%s': delivery to %s:%s failed: %s", job["id"], platform_name, chat_id, e)
|
||||
@@ -167,18 +196,23 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
logger.error("Job '%s': delivery error: %s", job["id"], result["error"])
|
||||
else:
|
||||
logger.info("Job '%s': delivered to %s:%s", job["id"], platform_name, chat_id)
|
||||
# Mirror the delivered content into the target's gateway session
|
||||
try:
|
||||
from gateway.mirror import mirror_to_session
|
||||
mirror_to_session(platform_name, chat_id, content, source_label="cron", thread_id=thread_id)
|
||||
except Exception as e:
|
||||
logger.warning("Job '%s': mirror_to_session failed: %s", job["id"], e)
|
||||
|
||||
|
||||
def _build_job_prompt(job: dict) -> str:
|
||||
"""Build the effective prompt for a cron job, optionally loading one or more skills first."""
|
||||
prompt = job.get("prompt", "")
|
||||
skills = job.get("skills")
|
||||
|
||||
# Always prepend [SILENT] guidance so the cron agent can suppress
|
||||
# delivery when it has nothing new or noteworthy to report.
|
||||
silent_hint = (
|
||||
"[SYSTEM: If you have nothing new or noteworthy to report, respond "
|
||||
"with exactly \"[SILENT]\" (optionally followed by a brief internal "
|
||||
"note). This suppresses delivery to the user while still saving "
|
||||
"output locally. Only use [SILENT] when there are genuinely no "
|
||||
"changes worth reporting.]\n\n"
|
||||
)
|
||||
prompt = silent_hint + prompt
|
||||
if skills is None:
|
||||
legacy = job.get("skill")
|
||||
skills = [legacy] if legacy else []
|
||||
@@ -190,11 +224,14 @@ def _build_job_prompt(job: dict) -> str:
|
||||
from tools.skills_tool import skill_view
|
||||
|
||||
parts = []
|
||||
skipped: list[str] = []
|
||||
for skill_name in skill_names:
|
||||
loaded = json.loads(skill_view(skill_name))
|
||||
if not loaded.get("success"):
|
||||
error = loaded.get("error") or f"Failed to load skill '{skill_name}'"
|
||||
raise RuntimeError(error)
|
||||
logger.warning("Cron job '%s': skill not found, skipping — %s", job.get("name", job.get("id")), error)
|
||||
skipped.append(skill_name)
|
||||
continue
|
||||
|
||||
content = str(loaded.get("content") or "").strip()
|
||||
if parts:
|
||||
@@ -207,6 +244,15 @@ def _build_job_prompt(job: dict) -> str:
|
||||
]
|
||||
)
|
||||
|
||||
if skipped:
|
||||
notice = (
|
||||
f"[SYSTEM: The following skill(s) were listed for this job but could not be found "
|
||||
f"and were skipped: {', '.join(skipped)}. "
|
||||
f"Start your response with a brief notice so the user is aware, e.g.: "
|
||||
f"'⚠️ Skill(s) not found and skipped: {', '.join(skipped)}']"
|
||||
)
|
||||
parts.insert(0, notice)
|
||||
|
||||
if prompt:
|
||||
parts.extend(["", f"The user has provided the following instruction alongside the skill invocation: {prompt}"])
|
||||
return "\n".join(parts)
|
||||
@@ -315,6 +361,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
# Provider routing
|
||||
pr = _cfg.get("provider_routing", {})
|
||||
smart_routing = _cfg.get("smart_model_routing", {}) or {}
|
||||
|
||||
from hermes_cli.runtime_provider import (
|
||||
resolve_runtime_provider,
|
||||
@@ -331,12 +378,29 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
message = format_runtime_provider_error(exc)
|
||||
raise RuntimeError(message) from exc
|
||||
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
turn_route = resolve_turn_route(
|
||||
prompt,
|
||||
smart_routing,
|
||||
{
|
||||
"model": model,
|
||||
"api_key": runtime.get("api_key"),
|
||||
"base_url": runtime.get("base_url"),
|
||||
"provider": runtime.get("provider"),
|
||||
"api_mode": runtime.get("api_mode"),
|
||||
"command": runtime.get("command"),
|
||||
"args": list(runtime.get("args") or []),
|
||||
},
|
||||
)
|
||||
|
||||
agent = AIAgent(
|
||||
model=model,
|
||||
api_key=runtime.get("api_key"),
|
||||
base_url=runtime.get("base_url"),
|
||||
provider=runtime.get("provider"),
|
||||
api_mode=runtime.get("api_mode"),
|
||||
model=turn_route["model"],
|
||||
api_key=turn_route["runtime"].get("api_key"),
|
||||
base_url=turn_route["runtime"].get("base_url"),
|
||||
provider=turn_route["runtime"].get("provider"),
|
||||
api_mode=turn_route["runtime"].get("api_mode"),
|
||||
acp_command=turn_route["runtime"].get("command"),
|
||||
acp_args=turn_route["runtime"].get("args"),
|
||||
max_iterations=max_iterations,
|
||||
reasoning_config=reasoning_config,
|
||||
prefill_messages=prefill_messages,
|
||||
@@ -344,7 +408,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
provider_sort=pr.get("sort"),
|
||||
disabled_toolsets=["cronjob"],
|
||||
disabled_toolsets=["cronjob", "messaging", "clarify"],
|
||||
quiet_mode=True,
|
||||
platform="cron",
|
||||
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}",
|
||||
@@ -353,9 +417,10 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
result = agent.run_conversation(prompt)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if not final_response:
|
||||
final_response = "(No response generated)"
|
||||
final_response = result.get("final_response", "") or ""
|
||||
# Use a separate variable for log display; keep final_response clean
|
||||
# for delivery logic (empty response = no delivery).
|
||||
logged_response = final_response if final_response else "(No response generated)"
|
||||
|
||||
output = f"""# Cron Job: {job_name}
|
||||
|
||||
@@ -369,7 +434,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
## Response
|
||||
|
||||
{final_response}
|
||||
{logged_response}
|
||||
"""
|
||||
|
||||
logger.info("Job '%s' completed successfully", job_name)
|
||||
@@ -465,9 +530,16 @@ def tick(verbose: bool = True) -> int:
|
||||
if verbose:
|
||||
logger.info("Output saved to: %s", output_file)
|
||||
|
||||
# Deliver the final response to the origin/target chat
|
||||
# Deliver the final response to the origin/target chat.
|
||||
# If the agent responded with [SILENT], skip delivery (but
|
||||
# output is already saved above). Failed jobs always deliver.
|
||||
deliver_content = final_response if success else f"⚠️ Cron job '{job.get('name', job['id'])}' failed:\n{error}"
|
||||
if deliver_content:
|
||||
should_deliver = bool(deliver_content)
|
||||
if should_deliver and success and deliver_content.strip().upper().startswith(SILENT_MARKER):
|
||||
logger.info("Job '%s': agent returned %s — skipping delivery", job["id"], SILENT_MARKER)
|
||||
should_deliver = False
|
||||
|
||||
if should_deliver:
|
||||
try:
|
||||
_deliver_result(job, deliver_content)
|
||||
except Exception as de:
|
||||
|
||||
@@ -0,0 +1,608 @@
|
||||
# Pricing Accuracy Architecture
|
||||
|
||||
Date: 2026-03-16
|
||||
|
||||
## Goal
|
||||
|
||||
Hermes should only show dollar costs when they are backed by an official source for the user's actual billing path.
|
||||
|
||||
This design replaces the current static, heuristic pricing flow in:
|
||||
|
||||
- `run_agent.py`
|
||||
- `agent/usage_pricing.py`
|
||||
- `agent/insights.py`
|
||||
- `cli.py`
|
||||
|
||||
with a provider-aware pricing system that:
|
||||
|
||||
- handles cache billing correctly
|
||||
- distinguishes `actual` vs `estimated` vs `included` vs `unknown`
|
||||
- reconciles post-hoc costs when providers expose authoritative billing data
|
||||
- supports direct providers, OpenRouter, subscriptions, enterprise pricing, and custom endpoints
|
||||
|
||||
## Problems In The Current Design
|
||||
|
||||
Current Hermes behavior has four structural issues:
|
||||
|
||||
1. It stores only `prompt_tokens` and `completion_tokens`, which is insufficient for providers that bill cache reads and cache writes separately.
|
||||
2. It uses a static model price table and fuzzy heuristics, which can drift from current official pricing.
|
||||
3. It assumes public API list pricing matches the user's real billing path.
|
||||
4. It has no distinction between live estimates and reconciled billed cost.
|
||||
|
||||
## Design Principles
|
||||
|
||||
1. Normalize usage before pricing.
|
||||
2. Never fold cached tokens into plain input cost.
|
||||
3. Track certainty explicitly.
|
||||
4. Treat the billing path as part of the model identity.
|
||||
5. Prefer official machine-readable sources over scraped docs.
|
||||
6. Use post-hoc provider cost APIs when available.
|
||||
7. Show `n/a` rather than inventing precision.
|
||||
|
||||
## High-Level Architecture
|
||||
|
||||
The new system has four layers:
|
||||
|
||||
1. `usage_normalization`
|
||||
Converts raw provider usage into a canonical usage record.
|
||||
2. `pricing_source_resolution`
|
||||
Determines the billing path, source of truth, and applicable pricing source.
|
||||
3. `cost_estimation_and_reconciliation`
|
||||
Produces an immediate estimate when possible, then replaces or annotates it with actual billed cost later.
|
||||
4. `presentation`
|
||||
`/usage`, `/insights`, and the status bar display cost with certainty metadata.
|
||||
|
||||
## Canonical Usage Record
|
||||
|
||||
Add a canonical usage model that every provider path maps into before any pricing math happens.
|
||||
|
||||
Suggested structure:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CanonicalUsage:
|
||||
provider: str
|
||||
billing_provider: str
|
||||
model: str
|
||||
billing_route: str
|
||||
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
cache_write_tokens: int = 0
|
||||
reasoning_tokens: int = 0
|
||||
request_count: int = 1
|
||||
|
||||
raw_usage: dict[str, Any] | None = None
|
||||
raw_usage_fields: dict[str, str] | None = None
|
||||
computed_fields: set[str] | None = None
|
||||
|
||||
provider_request_id: str | None = None
|
||||
provider_generation_id: str | None = None
|
||||
provider_response_id: str | None = None
|
||||
```
|
||||
|
||||
Rules:
|
||||
|
||||
- `input_tokens` means non-cached input only.
|
||||
- `cache_read_tokens` and `cache_write_tokens` are never merged into `input_tokens`.
|
||||
- `output_tokens` excludes cache metrics.
|
||||
- `reasoning_tokens` is telemetry unless a provider officially bills it separately.
|
||||
|
||||
This is the same normalization pattern used by `opencode`, extended with provenance and reconciliation ids.
|
||||
|
||||
## Provider Normalization Rules
|
||||
|
||||
### OpenAI Direct
|
||||
|
||||
Source usage fields:
|
||||
|
||||
- `prompt_tokens`
|
||||
- `completion_tokens`
|
||||
- `prompt_tokens_details.cached_tokens`
|
||||
|
||||
Normalization:
|
||||
|
||||
- `cache_read_tokens = cached_tokens`
|
||||
- `input_tokens = prompt_tokens - cached_tokens`
|
||||
- `cache_write_tokens = 0` unless OpenAI exposes it in the relevant route
|
||||
- `output_tokens = completion_tokens`
|
||||
|
||||
### Anthropic Direct
|
||||
|
||||
Source usage fields:
|
||||
|
||||
- `input_tokens`
|
||||
- `output_tokens`
|
||||
- `cache_read_input_tokens`
|
||||
- `cache_creation_input_tokens`
|
||||
|
||||
Normalization:
|
||||
|
||||
- `input_tokens = input_tokens`
|
||||
- `output_tokens = output_tokens`
|
||||
- `cache_read_tokens = cache_read_input_tokens`
|
||||
- `cache_write_tokens = cache_creation_input_tokens`
|
||||
|
||||
### OpenRouter
|
||||
|
||||
Estimate-time usage normalization should use the response usage payload with the same rules as the underlying provider when possible.
|
||||
|
||||
Reconciliation-time records should also store:
|
||||
|
||||
- OpenRouter generation id
|
||||
- native token fields when available
|
||||
- `total_cost`
|
||||
- `cache_discount`
|
||||
- `upstream_inference_cost`
|
||||
- `is_byok`
|
||||
|
||||
### Gemini / Vertex
|
||||
|
||||
Use official Gemini or Vertex usage fields where available.
|
||||
|
||||
If cached content tokens are exposed:
|
||||
|
||||
- map them to `cache_read_tokens`
|
||||
|
||||
If a route exposes no cache creation metric:
|
||||
|
||||
- store `cache_write_tokens = 0`
|
||||
- preserve the raw usage payload for later extension
|
||||
|
||||
### DeepSeek And Other Direct Providers
|
||||
|
||||
Normalize only the fields that are officially exposed.
|
||||
|
||||
If a provider does not expose cache buckets:
|
||||
|
||||
- do not infer them unless the provider explicitly documents how to derive them
|
||||
|
||||
### Subscription / Included-Cost Routes
|
||||
|
||||
These still use the canonical usage model.
|
||||
|
||||
Tokens are tracked normally. Cost depends on billing mode, not on whether usage exists.
|
||||
|
||||
## Billing Route Model
|
||||
|
||||
Hermes must stop keying pricing solely by `model`.
|
||||
|
||||
Introduce a billing route descriptor:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class BillingRoute:
|
||||
provider: str
|
||||
base_url: str | None
|
||||
model: str
|
||||
billing_mode: str
|
||||
organization_hint: str | None = None
|
||||
```
|
||||
|
||||
`billing_mode` values:
|
||||
|
||||
- `official_cost_api`
|
||||
- `official_generation_api`
|
||||
- `official_models_api`
|
||||
- `official_docs_snapshot`
|
||||
- `subscription_included`
|
||||
- `user_override`
|
||||
- `custom_contract`
|
||||
- `unknown`
|
||||
|
||||
Examples:
|
||||
|
||||
- OpenAI direct API with Costs API access: `official_cost_api`
|
||||
- Anthropic direct API with Usage & Cost API access: `official_cost_api`
|
||||
- OpenRouter request before reconciliation: `official_models_api`
|
||||
- OpenRouter request after generation lookup: `official_generation_api`
|
||||
- GitHub Copilot style subscription route: `subscription_included`
|
||||
- local OpenAI-compatible server: `unknown`
|
||||
- enterprise contract with configured rates: `custom_contract`
|
||||
|
||||
## Cost Status Model
|
||||
|
||||
Every displayed cost should have:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CostResult:
|
||||
amount_usd: Decimal | None
|
||||
status: Literal["actual", "estimated", "included", "unknown"]
|
||||
source: Literal[
|
||||
"provider_cost_api",
|
||||
"provider_generation_api",
|
||||
"provider_models_api",
|
||||
"official_docs_snapshot",
|
||||
"user_override",
|
||||
"custom_contract",
|
||||
"none",
|
||||
]
|
||||
label: str
|
||||
fetched_at: datetime | None
|
||||
pricing_version: str | None
|
||||
notes: list[str]
|
||||
```
|
||||
|
||||
Presentation rules:
|
||||
|
||||
- `actual`: show dollar amount as final
|
||||
- `estimated`: show dollar amount with estimate labeling
|
||||
- `included`: show `included` or `$0.00 (included)` depending on UX choice
|
||||
- `unknown`: show `n/a`
|
||||
|
||||
## Official Source Hierarchy
|
||||
|
||||
Resolve cost using this order:
|
||||
|
||||
1. Request-level or account-level official billed cost
|
||||
2. Official machine-readable model pricing
|
||||
3. Official docs snapshot
|
||||
4. User override or custom contract
|
||||
5. Unknown
|
||||
|
||||
The system must never skip to a lower level if a higher-confidence source exists for the current billing route.
|
||||
|
||||
## Provider-Specific Truth Rules
|
||||
|
||||
### OpenAI Direct
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. Costs API for reconciled spend
|
||||
2. Official pricing page for live estimate
|
||||
|
||||
### Anthropic Direct
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. Usage & Cost API for reconciled spend
|
||||
2. Official pricing docs for live estimate
|
||||
|
||||
### OpenRouter
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. `GET /api/v1/generation` for reconciled `total_cost`
|
||||
2. `GET /api/v1/models` pricing for live estimate
|
||||
|
||||
Do not use underlying provider public pricing as the source of truth for OpenRouter billing.
|
||||
|
||||
### Gemini / Vertex
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. official billing export or billing API for reconciled spend when available for the route
|
||||
2. official pricing docs for estimate
|
||||
|
||||
### DeepSeek
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. official machine-readable cost source if available in the future
|
||||
2. official pricing docs snapshot today
|
||||
|
||||
### Subscription-Included Routes
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. explicit route config marking the model as included in subscription
|
||||
|
||||
These should display `included`, not an API list-price estimate.
|
||||
|
||||
### Custom Endpoint / Local Model
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. user override
|
||||
2. custom contract config
|
||||
3. unknown
|
||||
|
||||
These should default to `unknown`.
|
||||
|
||||
## Pricing Catalog
|
||||
|
||||
Replace the current `MODEL_PRICING` dict with a richer pricing catalog.
|
||||
|
||||
Suggested record:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class PricingEntry:
|
||||
provider: str
|
||||
route_pattern: str
|
||||
model_pattern: str
|
||||
|
||||
input_cost_per_million: Decimal | None = None
|
||||
output_cost_per_million: Decimal | None = None
|
||||
cache_read_cost_per_million: Decimal | None = None
|
||||
cache_write_cost_per_million: Decimal | None = None
|
||||
request_cost: Decimal | None = None
|
||||
image_cost: Decimal | None = None
|
||||
|
||||
source: str = "official_docs_snapshot"
|
||||
source_url: str | None = None
|
||||
fetched_at: datetime | None = None
|
||||
pricing_version: str | None = None
|
||||
```
|
||||
|
||||
The catalog should be route-aware:
|
||||
|
||||
- `openai:gpt-5`
|
||||
- `anthropic:claude-opus-4-6`
|
||||
- `openrouter:anthropic/claude-opus-4.6`
|
||||
- `copilot:gpt-4o`
|
||||
|
||||
This avoids conflating direct-provider billing with aggregator billing.
|
||||
|
||||
## Pricing Sync Architecture
|
||||
|
||||
Introduce a pricing sync subsystem instead of manually maintaining a single hardcoded table.
|
||||
|
||||
Suggested modules:
|
||||
|
||||
- `agent/pricing/catalog.py`
|
||||
- `agent/pricing/sources.py`
|
||||
- `agent/pricing/sync.py`
|
||||
- `agent/pricing/reconcile.py`
|
||||
- `agent/pricing/types.py`
|
||||
|
||||
### Sync Sources
|
||||
|
||||
- OpenRouter models API
|
||||
- official provider docs snapshots where no API exists
|
||||
- user overrides from config
|
||||
|
||||
### Sync Output
|
||||
|
||||
Cache pricing entries locally with:
|
||||
|
||||
- source URL
|
||||
- fetch timestamp
|
||||
- version/hash
|
||||
- confidence/source type
|
||||
|
||||
### Sync Frequency
|
||||
|
||||
- startup warm cache
|
||||
- background refresh every 6 to 24 hours depending on source
|
||||
- manual `hermes pricing sync`
|
||||
|
||||
## Reconciliation Architecture
|
||||
|
||||
Live requests may produce only an estimate initially. Hermes should reconcile them later when a provider exposes actual billed cost.
|
||||
|
||||
Suggested flow:
|
||||
|
||||
1. Agent call completes.
|
||||
2. Hermes stores canonical usage plus reconciliation ids.
|
||||
3. Hermes computes an immediate estimate if a pricing source exists.
|
||||
4. A reconciliation worker fetches actual cost when supported.
|
||||
5. Session and message records are updated with `actual` cost.
|
||||
|
||||
This can run:
|
||||
|
||||
- inline for cheap lookups
|
||||
- asynchronously for delayed provider accounting
|
||||
|
||||
## Persistence Changes
|
||||
|
||||
Session storage should stop storing only aggregate prompt/completion totals.
|
||||
|
||||
Add fields for both usage and cost certainty:
|
||||
|
||||
- `input_tokens`
|
||||
- `output_tokens`
|
||||
- `cache_read_tokens`
|
||||
- `cache_write_tokens`
|
||||
- `reasoning_tokens`
|
||||
- `estimated_cost_usd`
|
||||
- `actual_cost_usd`
|
||||
- `cost_status`
|
||||
- `cost_source`
|
||||
- `pricing_version`
|
||||
- `billing_provider`
|
||||
- `billing_mode`
|
||||
|
||||
If schema expansion is too large for one PR, add a new pricing events table:
|
||||
|
||||
```text
|
||||
session_cost_events
|
||||
id
|
||||
session_id
|
||||
request_id
|
||||
provider
|
||||
model
|
||||
billing_mode
|
||||
input_tokens
|
||||
output_tokens
|
||||
cache_read_tokens
|
||||
cache_write_tokens
|
||||
estimated_cost_usd
|
||||
actual_cost_usd
|
||||
cost_status
|
||||
cost_source
|
||||
pricing_version
|
||||
created_at
|
||||
updated_at
|
||||
```
|
||||
|
||||
## Hermes Touchpoints
|
||||
|
||||
### `run_agent.py`
|
||||
|
||||
Current responsibility:
|
||||
|
||||
- parse raw provider usage
|
||||
- update session token counters
|
||||
|
||||
New responsibility:
|
||||
|
||||
- build `CanonicalUsage`
|
||||
- update canonical counters
|
||||
- store reconciliation ids
|
||||
- emit usage event to pricing subsystem
|
||||
|
||||
### `agent/usage_pricing.py`
|
||||
|
||||
Current responsibility:
|
||||
|
||||
- static lookup table
|
||||
- direct cost arithmetic
|
||||
|
||||
New responsibility:
|
||||
|
||||
- move or replace with pricing catalog facade
|
||||
- no fuzzy model-family heuristics
|
||||
- no direct pricing without billing-route context
|
||||
|
||||
### `cli.py`
|
||||
|
||||
Current responsibility:
|
||||
|
||||
- compute session cost directly from prompt/completion totals
|
||||
|
||||
New responsibility:
|
||||
|
||||
- display `CostResult`
|
||||
- show status badges:
|
||||
- `actual`
|
||||
- `estimated`
|
||||
- `included`
|
||||
- `n/a`
|
||||
|
||||
### `agent/insights.py`
|
||||
|
||||
Current responsibility:
|
||||
|
||||
- recompute historical estimates from static pricing
|
||||
|
||||
New responsibility:
|
||||
|
||||
- aggregate stored pricing events
|
||||
- prefer actual cost over estimate
|
||||
- surface estimates only when reconciliation is unavailable
|
||||
|
||||
## UX Rules
|
||||
|
||||
### Status Bar
|
||||
|
||||
Show one of:
|
||||
|
||||
- `$1.42`
|
||||
- `~$1.42`
|
||||
- `included`
|
||||
- `cost n/a`
|
||||
|
||||
Where:
|
||||
|
||||
- `$1.42` means `actual`
|
||||
- `~$1.42` means `estimated`
|
||||
- `included` means subscription-backed or explicitly zero-cost route
|
||||
- `cost n/a` means unknown
|
||||
|
||||
### `/usage`
|
||||
|
||||
Show:
|
||||
|
||||
- token buckets
|
||||
- estimated cost
|
||||
- actual cost if available
|
||||
- cost status
|
||||
- pricing source
|
||||
|
||||
### `/insights`
|
||||
|
||||
Aggregate:
|
||||
|
||||
- actual cost totals
|
||||
- estimated-only totals
|
||||
- unknown-cost sessions count
|
||||
- included-cost sessions count
|
||||
|
||||
## Config And Overrides
|
||||
|
||||
Add user-configurable pricing overrides in config:
|
||||
|
||||
```yaml
|
||||
pricing:
|
||||
mode: hybrid
|
||||
sync_on_startup: true
|
||||
sync_interval_hours: 12
|
||||
overrides:
|
||||
- provider: openrouter
|
||||
model: anthropic/claude-opus-4.6
|
||||
billing_mode: custom_contract
|
||||
input_cost_per_million: 4.25
|
||||
output_cost_per_million: 22.0
|
||||
cache_read_cost_per_million: 0.5
|
||||
cache_write_cost_per_million: 6.0
|
||||
included_routes:
|
||||
- provider: copilot
|
||||
model: "*"
|
||||
- provider: codex-subscription
|
||||
model: "*"
|
||||
```
|
||||
|
||||
Overrides must win over catalog defaults for the matching billing route.
|
||||
|
||||
## Rollout Plan
|
||||
|
||||
### Phase 1
|
||||
|
||||
- add canonical usage model
|
||||
- split cache token buckets in `run_agent.py`
|
||||
- stop pricing cache-inflated prompt totals
|
||||
- preserve current UI with improved backend math
|
||||
|
||||
### Phase 2
|
||||
|
||||
- add route-aware pricing catalog
|
||||
- integrate OpenRouter models API sync
|
||||
- add `estimated` vs `included` vs `unknown`
|
||||
|
||||
### Phase 3
|
||||
|
||||
- add reconciliation for OpenRouter generation cost
|
||||
- add actual cost persistence
|
||||
- update `/insights` to prefer actual cost
|
||||
|
||||
### Phase 4
|
||||
|
||||
- add direct OpenAI and Anthropic reconciliation paths
|
||||
- add user overrides and contract pricing
|
||||
- add pricing sync CLI command
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
Add tests for:
|
||||
|
||||
- OpenAI cached token subtraction
|
||||
- Anthropic cache read/write separation
|
||||
- OpenRouter estimated vs actual reconciliation
|
||||
- subscription-backed models showing `included`
|
||||
- custom endpoints showing `n/a`
|
||||
- override precedence
|
||||
- stale catalog fallback behavior
|
||||
|
||||
Current tests that assume heuristic pricing should be replaced with route-aware expectations.
|
||||
|
||||
## Non-Goals
|
||||
|
||||
- exact enterprise billing reconstruction without an official source or user override
|
||||
- backfilling perfect historical cost for old sessions that lack cache bucket data
|
||||
- scraping arbitrary provider web pages at request time
|
||||
|
||||
## Recommendation
|
||||
|
||||
Do not expand the existing `MODEL_PRICING` dict.
|
||||
|
||||
That path cannot satisfy the product requirement. Hermes should instead migrate to:
|
||||
|
||||
- canonical usage normalization
|
||||
- route-aware pricing sources
|
||||
- estimate-then-reconcile cost lifecycle
|
||||
- explicit certainty states in the UI
|
||||
|
||||
This is the minimum architecture that makes the statement "Hermes pricing is backed by official sources where possible, and otherwise clearly labeled" defensible.
|
||||
+72
-61
@@ -346,78 +346,89 @@ class HermesAgentLoop:
|
||||
tool_name, turn + 1,
|
||||
)
|
||||
else:
|
||||
# Parse arguments and dispatch
|
||||
# Parse arguments
|
||||
try:
|
||||
args = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
except json.JSONDecodeError as e:
|
||||
args = None
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Invalid JSON in tool arguments: {e}. Please retry with valid JSON."}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"Invalid JSON: {e}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.warning(
|
||||
"Invalid JSON in tool call arguments for '%s': %s",
|
||||
tool_name, tool_args_raw[:200],
|
||||
)
|
||||
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
cmd_preview = args.get("command", "")[:80]
|
||||
logger.info(
|
||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
||||
)
|
||||
# Dispatch tool only if arguments parsed successfully
|
||||
if args is not None:
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
cmd_preview = args.get("command", "")[:80]
|
||||
logger.info(
|
||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
||||
)
|
||||
|
||||
tool_submit_time = _time.monotonic()
|
||||
tool_submit_time = _time.monotonic()
|
||||
|
||||
# Todo tool -- handle locally (needs per-loop TodoStore)
|
||||
if tool_name == "todo":
|
||||
tool_result = _todo_tool(
|
||||
todos=args.get("todos"),
|
||||
merge=args.get("merge", False),
|
||||
store=_todo_store,
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "memory":
|
||||
tool_result = json.dumps({"error": "Memory is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "session_search":
|
||||
tool_result = json.dumps({"error": "Session search is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
else:
|
||||
# Run tool calls in a thread pool so backends that
|
||||
# use asyncio.run() internally (modal, docker, daytona) get
|
||||
# a clean event loop instead of deadlocking.
|
||||
loop = asyncio.get_event_loop()
|
||||
# Capture current tool_name/args for the lambda
|
||||
_tn, _ta, _tid = tool_name, args, self.task_id
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
_tn, _ta, task_id=_tid,
|
||||
user_task=_user_task,
|
||||
),
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
# Todo tool -- handle locally (needs per-loop TodoStore)
|
||||
if tool_name == "todo":
|
||||
tool_result = _todo_tool(
|
||||
todos=args.get("todos"),
|
||||
merge=args.get("merge", False),
|
||||
store=_todo_store,
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "memory":
|
||||
tool_result = json.dumps({"error": "Memory is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "session_search":
|
||||
tool_result = json.dumps({"error": "Session search is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
else:
|
||||
# Run tool calls in a thread pool so backends that
|
||||
# use asyncio.run() internally (modal, docker, daytona) get
|
||||
# a clean event loop instead of deadlocking.
|
||||
loop = asyncio.get_event_loop()
|
||||
# Capture current tool_name/args for the lambda
|
||||
_tn, _ta, _tid = tool_name, args, self.task_id
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
_tn, _ta, task_id=_tid,
|
||||
user_task=_user_task,
|
||||
),
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
|
||||
# Log slow tools and thread pool stats for debugging
|
||||
pool_active = _tool_executor._work_queue.qsize()
|
||||
if tool_elapsed > 30:
|
||||
logger.warning(
|
||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
||||
self.task_id[:8], turn + 1, tool_name,
|
||||
tool_elapsed, pool_active,
|
||||
# Log slow tools and thread pool stats for debugging
|
||||
pool_active = _tool_executor._work_queue.qsize()
|
||||
if tool_elapsed > 30:
|
||||
logger.warning(
|
||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
||||
self.task_id[:8], turn + 1, tool_name,
|
||||
tool_elapsed, pool_active,
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"{type(e).__name__}: {str(e)}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.error(
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"{type(e).__name__}: {str(e)}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.error(
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
|
||||
# Also check if the tool returned an error in its JSON result
|
||||
try:
|
||||
|
||||
@@ -10,7 +10,6 @@ The [TOOL_CALLS] token is the bot_token used by Mistral models.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -42,9 +41,6 @@ class MistralToolCallParser(ToolCallParser):
|
||||
# The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
|
||||
BOT_TOKEN = "[TOOL_CALLS]"
|
||||
|
||||
# Fallback regex for pre-v11 format when JSON parsing fails
|
||||
TOOL_CALL_REGEX = re.compile(r"\[?\s*(\{.*?\})\s*\]?", re.DOTALL)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.BOT_TOKEN not in text:
|
||||
return text, None
|
||||
@@ -71,6 +67,13 @@ class MistralToolCallParser(ToolCallParser):
|
||||
tool_name = raw[:brace_idx].strip()
|
||||
args_str = raw[brace_idx:]
|
||||
|
||||
# Validate and clean the JSON arguments
|
||||
try:
|
||||
parsed_args = json.loads(args_str)
|
||||
args_str = json.dumps(parsed_args, ensure_ascii=False)
|
||||
except json.JSONDecodeError:
|
||||
pass # Keep raw if parsing fails
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
@@ -100,13 +103,14 @@ class MistralToolCallParser(ToolCallParser):
|
||||
)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Fallback regex extraction
|
||||
match = self.TOOL_CALL_REGEX.findall(first_raw)
|
||||
if match:
|
||||
for raw_json in match:
|
||||
try:
|
||||
tc = json.loads(raw_json)
|
||||
args = tc.get("arguments", {})
|
||||
# Fallback: extract JSON objects using raw_decode
|
||||
decoder = json.JSONDecoder()
|
||||
idx = 0
|
||||
while idx < len(first_raw):
|
||||
try:
|
||||
obj, end_idx = decoder.raw_decode(first_raw, idx)
|
||||
if isinstance(obj, dict) and "name" in obj:
|
||||
args = obj.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
@@ -114,12 +118,13 @@ class MistralToolCallParser(ToolCallParser):
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc["name"], arguments=args
|
||||
name=obj["name"], arguments=args
|
||||
),
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
idx = end_idx
|
||||
except json.JSONDecodeError:
|
||||
idx += 1
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
@@ -63,7 +63,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
logger.warning("Channel directory: failed to build %s: %s", platform.value, e)
|
||||
|
||||
# Telegram, WhatsApp & Signal can't enumerate chats -- pull from session history
|
||||
for plat_name in ("telegram", "whatsapp", "signal", "email"):
|
||||
for plat_name in ("telegram", "whatsapp", "signal", "email", "sms"):
|
||||
if plat_name not in platforms:
|
||||
platforms[plat_name] = _build_from_sessions(plat_name)
|
||||
|
||||
|
||||
+304
-39
@@ -32,6 +32,15 @@ def _coerce_bool(value: Any, default: bool = True) -> bool:
|
||||
return bool(value)
|
||||
|
||||
|
||||
def _normalize_unauthorized_dm_behavior(value: Any, default: str = "pair") -> str:
|
||||
"""Normalize unauthorized DM behavior to a supported value."""
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip().lower()
|
||||
if normalized in {"pair", "ignore"}:
|
||||
return normalized
|
||||
return default
|
||||
|
||||
|
||||
class Platform(Enum):
|
||||
"""Supported messaging platforms."""
|
||||
LOCAL = "local"
|
||||
@@ -40,8 +49,14 @@ class Platform(Enum):
|
||||
WHATSAPP = "whatsapp"
|
||||
SLACK = "slack"
|
||||
SIGNAL = "signal"
|
||||
MATTERMOST = "mattermost"
|
||||
MATRIX = "matrix"
|
||||
HOMEASSISTANT = "homeassistant"
|
||||
EMAIL = "email"
|
||||
SMS = "sms"
|
||||
DINGTALK = "dingtalk"
|
||||
API_SERVER = "api_server"
|
||||
WEBHOOK = "webhook"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -86,23 +101,32 @@ class SessionResetPolicy:
|
||||
mode: str = "both" # "daily", "idle", "both", or "none"
|
||||
at_hour: int = 4 # Hour for daily reset (0-23, local time)
|
||||
idle_minutes: int = 1440 # Minutes of inactivity before reset (24 hours)
|
||||
notify: bool = True # Send a notification to the user when auto-reset occurs
|
||||
notify_exclude_platforms: tuple = ("api_server", "webhook") # Platforms that don't get reset notifications
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"mode": self.mode,
|
||||
"at_hour": self.at_hour,
|
||||
"idle_minutes": self.idle_minutes,
|
||||
"notify": self.notify,
|
||||
"notify_exclude_platforms": list(self.notify_exclude_platforms),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionResetPolicy":
|
||||
# Handle both missing keys and explicit null values (YAML null → None)
|
||||
mode = data.get("mode")
|
||||
at_hour = data.get("at_hour")
|
||||
idle_minutes = data.get("idle_minutes")
|
||||
notify = data.get("notify")
|
||||
exclude = data.get("notify_exclude_platforms")
|
||||
return cls(
|
||||
mode=data.get("mode", "both"),
|
||||
mode=mode if mode is not None else "both",
|
||||
at_hour=at_hour if at_hour is not None else 4,
|
||||
idle_minutes=idle_minutes if idle_minutes is not None else 1440,
|
||||
notify=notify if notify is not None else True,
|
||||
notify_exclude_platforms=tuple(exclude) if exclude is not None else ("api_server", "webhook"),
|
||||
)
|
||||
|
||||
|
||||
@@ -145,6 +169,37 @@ class PlatformConfig:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamingConfig:
|
||||
"""Configuration for real-time token streaming to messaging platforms."""
|
||||
enabled: bool = False
|
||||
transport: str = "edit" # "edit" (progressive editMessageText) or "off"
|
||||
edit_interval: float = 0.3 # Seconds between message edits
|
||||
buffer_threshold: int = 40 # Chars before forcing an edit
|
||||
cursor: str = " ▉" # Cursor shown during streaming
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"transport": self.transport,
|
||||
"edit_interval": self.edit_interval,
|
||||
"buffer_threshold": self.buffer_threshold,
|
||||
"cursor": self.cursor,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "StreamingConfig":
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
enabled=data.get("enabled", False),
|
||||
transport=data.get("transport", "edit"),
|
||||
edit_interval=float(data.get("edit_interval", 0.3)),
|
||||
buffer_threshold=int(data.get("buffer_threshold", 40)),
|
||||
cursor=data.get("cursor", " ▉"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GatewayConfig:
|
||||
"""
|
||||
@@ -174,7 +229,16 @@ class GatewayConfig:
|
||||
|
||||
# STT settings
|
||||
stt_enabled: bool = True # Whether to auto-transcribe inbound voice messages
|
||||
|
||||
|
||||
# Session isolation in shared chats
|
||||
group_sessions_per_user: bool = True # Isolate group/channel sessions per participant when user IDs are available
|
||||
|
||||
# Unauthorized DM policy
|
||||
unauthorized_dm_behavior: str = "pair" # "pair" or "ignore"
|
||||
|
||||
# Streaming configuration
|
||||
streaming: StreamingConfig = field(default_factory=StreamingConfig)
|
||||
|
||||
def get_connected_platforms(self) -> List[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
connected = []
|
||||
@@ -193,6 +257,15 @@ class GatewayConfig:
|
||||
# Email uses extra dict for config (address + imap_host + smtp_host)
|
||||
elif platform == Platform.EMAIL and config.extra.get("address"):
|
||||
connected.append(platform)
|
||||
# SMS uses api_key (Twilio auth token) — SID checked via env
|
||||
elif platform == Platform.SMS and os.getenv("TWILIO_ACCOUNT_SID"):
|
||||
connected.append(platform)
|
||||
# API Server uses enabled flag only (no token needed)
|
||||
elif platform == Platform.API_SERVER:
|
||||
connected.append(platform)
|
||||
# Webhook uses enabled flag only (secrets are per-route)
|
||||
elif platform == Platform.WEBHOOK:
|
||||
connected.append(platform)
|
||||
return connected
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]:
|
||||
@@ -239,6 +312,9 @@ class GatewayConfig:
|
||||
"sessions_dir": str(self.sessions_dir),
|
||||
"always_log_local": self.always_log_local,
|
||||
"stt_enabled": self.stt_enabled,
|
||||
"group_sessions_per_user": self.group_sessions_per_user,
|
||||
"unauthorized_dm_behavior": self.unauthorized_dm_behavior,
|
||||
"streaming": self.streaming.to_dict(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -279,6 +355,12 @@ class GatewayConfig:
|
||||
if stt_enabled is None:
|
||||
stt_enabled = data.get("stt", {}).get("enabled") if isinstance(data.get("stt"), dict) else None
|
||||
|
||||
group_sessions_per_user = data.get("group_sessions_per_user")
|
||||
unauthorized_dm_behavior = _normalize_unauthorized_dm_behavior(
|
||||
data.get("unauthorized_dm_behavior"),
|
||||
"pair",
|
||||
)
|
||||
|
||||
return cls(
|
||||
platforms=platforms,
|
||||
default_reset_policy=default_policy,
|
||||
@@ -289,63 +371,147 @@ class GatewayConfig:
|
||||
sessions_dir=sessions_dir,
|
||||
always_log_local=data.get("always_log_local", True),
|
||||
stt_enabled=_coerce_bool(stt_enabled, True),
|
||||
group_sessions_per_user=_coerce_bool(group_sessions_per_user, True),
|
||||
unauthorized_dm_behavior=unauthorized_dm_behavior,
|
||||
streaming=StreamingConfig.from_dict(data.get("streaming", {})),
|
||||
)
|
||||
|
||||
def get_unauthorized_dm_behavior(self, platform: Optional[Platform] = None) -> str:
|
||||
"""Return the effective unauthorized-DM behavior for a platform."""
|
||||
if platform:
|
||||
platform_cfg = self.platforms.get(platform)
|
||||
if platform_cfg and "unauthorized_dm_behavior" in platform_cfg.extra:
|
||||
return _normalize_unauthorized_dm_behavior(
|
||||
platform_cfg.extra.get("unauthorized_dm_behavior"),
|
||||
self.unauthorized_dm_behavior,
|
||||
)
|
||||
return self.unauthorized_dm_behavior
|
||||
|
||||
|
||||
def load_gateway_config() -> GatewayConfig:
|
||||
"""
|
||||
Load gateway configuration from multiple sources.
|
||||
|
||||
|
||||
Priority (highest to lowest):
|
||||
1. Environment variables
|
||||
2. ~/.hermes/gateway.json
|
||||
3. cli-config.yaml gateway section
|
||||
4. Defaults
|
||||
2. ~/.hermes/config.yaml (primary user-facing config)
|
||||
3. ~/.hermes/gateway.json (legacy — provides defaults under config.yaml)
|
||||
4. Built-in defaults
|
||||
"""
|
||||
config = GatewayConfig()
|
||||
|
||||
# Try loading from ~/.hermes/gateway.json
|
||||
_home = get_hermes_home()
|
||||
gateway_config_path = _home / "gateway.json"
|
||||
if gateway_config_path.exists():
|
||||
try:
|
||||
with open(gateway_config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
config = GatewayConfig.from_dict(data)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load {gateway_config_path}: {e}")
|
||||
gw_data: dict = {}
|
||||
|
||||
# Bridge session_reset from config.yaml (the user-facing config file)
|
||||
# into the gateway config. config.yaml takes precedence over gateway.json
|
||||
# for session reset policy since that's where hermes setup writes it.
|
||||
# Legacy fallback: gateway.json provides the base layer.
|
||||
# config.yaml keys always win when both specify the same setting.
|
||||
gateway_json_path = _home / "gateway.json"
|
||||
if gateway_json_path.exists():
|
||||
try:
|
||||
with open(gateway_json_path, "r", encoding="utf-8") as f:
|
||||
gw_data = json.load(f) or {}
|
||||
logger.info(
|
||||
"Loaded legacy %s — consider moving settings to config.yaml",
|
||||
gateway_json_path,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load %s: %s", gateway_json_path, e)
|
||||
|
||||
# Primary source: config.yaml
|
||||
try:
|
||||
import yaml
|
||||
config_yaml_path = _home / "config.yaml"
|
||||
if config_yaml_path.exists():
|
||||
with open(config_yaml_path, encoding="utf-8") as f:
|
||||
yaml_cfg = yaml.safe_load(f) or {}
|
||||
|
||||
# Map config.yaml keys → GatewayConfig.from_dict() schema.
|
||||
# Each key overwrites whatever gateway.json may have set.
|
||||
sr = yaml_cfg.get("session_reset")
|
||||
if sr and isinstance(sr, dict):
|
||||
config.default_reset_policy = SessionResetPolicy.from_dict(sr)
|
||||
gw_data["default_reset_policy"] = sr
|
||||
|
||||
# Bridge quick commands from config.yaml into gateway runtime config.
|
||||
# config.yaml is the user-facing config source, so when present it
|
||||
# should override gateway.json for this setting.
|
||||
qc = yaml_cfg.get("quick_commands")
|
||||
if qc is not None:
|
||||
if isinstance(qc, dict):
|
||||
config.quick_commands = qc
|
||||
gw_data["quick_commands"] = qc
|
||||
else:
|
||||
logger.warning("Ignoring invalid quick_commands in config.yaml (expected mapping, got %s)", type(qc).__name__)
|
||||
logger.warning(
|
||||
"Ignoring invalid quick_commands in config.yaml "
|
||||
"(expected mapping, got %s)",
|
||||
type(qc).__name__,
|
||||
)
|
||||
|
||||
# Bridge STT enable/disable from config.yaml into gateway runtime.
|
||||
# This keeps the gateway aligned with the user-facing config source.
|
||||
stt_cfg = yaml_cfg.get("stt")
|
||||
if isinstance(stt_cfg, dict) and "enabled" in stt_cfg:
|
||||
config.stt_enabled = _coerce_bool(stt_cfg.get("enabled"), True)
|
||||
if isinstance(stt_cfg, dict):
|
||||
gw_data["stt"] = stt_cfg
|
||||
|
||||
# Bridge discord settings from config.yaml to env vars
|
||||
# (env vars take precedence — only set if not already defined)
|
||||
if "group_sessions_per_user" in yaml_cfg:
|
||||
gw_data["group_sessions_per_user"] = yaml_cfg["group_sessions_per_user"]
|
||||
|
||||
streaming_cfg = yaml_cfg.get("streaming")
|
||||
if isinstance(streaming_cfg, dict):
|
||||
gw_data["streaming"] = streaming_cfg
|
||||
|
||||
if "reset_triggers" in yaml_cfg:
|
||||
gw_data["reset_triggers"] = yaml_cfg["reset_triggers"]
|
||||
|
||||
if "always_log_local" in yaml_cfg:
|
||||
gw_data["always_log_local"] = yaml_cfg["always_log_local"]
|
||||
|
||||
if "unauthorized_dm_behavior" in yaml_cfg:
|
||||
gw_data["unauthorized_dm_behavior"] = _normalize_unauthorized_dm_behavior(
|
||||
yaml_cfg.get("unauthorized_dm_behavior"),
|
||||
"pair",
|
||||
)
|
||||
|
||||
# Merge platforms section from config.yaml into gw_data so that
|
||||
# nested keys like platforms.webhook.extra.routes are loaded.
|
||||
yaml_platforms = yaml_cfg.get("platforms")
|
||||
platforms_data = gw_data.setdefault("platforms", {})
|
||||
if not isinstance(platforms_data, dict):
|
||||
platforms_data = {}
|
||||
gw_data["platforms"] = platforms_data
|
||||
if isinstance(yaml_platforms, dict):
|
||||
for plat_name, plat_block in yaml_platforms.items():
|
||||
if not isinstance(plat_block, dict):
|
||||
continue
|
||||
existing = platforms_data.get(plat_name, {})
|
||||
if not isinstance(existing, dict):
|
||||
existing = {}
|
||||
# Deep-merge extra dicts so gateway.json defaults survive
|
||||
merged_extra = {**existing.get("extra", {}), **plat_block.get("extra", {})}
|
||||
merged = {**existing, **plat_block}
|
||||
if merged_extra:
|
||||
merged["extra"] = merged_extra
|
||||
platforms_data[plat_name] = merged
|
||||
gw_data["platforms"] = platforms_data
|
||||
for plat in Platform:
|
||||
if plat == Platform.LOCAL:
|
||||
continue
|
||||
platform_cfg = yaml_cfg.get(plat.value)
|
||||
if not isinstance(platform_cfg, dict):
|
||||
continue
|
||||
# Collect bridgeable keys from this platform section
|
||||
bridged = {}
|
||||
if "unauthorized_dm_behavior" in platform_cfg:
|
||||
bridged["unauthorized_dm_behavior"] = _normalize_unauthorized_dm_behavior(
|
||||
platform_cfg.get("unauthorized_dm_behavior"),
|
||||
gw_data.get("unauthorized_dm_behavior", "pair"),
|
||||
)
|
||||
if "reply_prefix" in platform_cfg:
|
||||
bridged["reply_prefix"] = platform_cfg["reply_prefix"]
|
||||
if not bridged:
|
||||
continue
|
||||
plat_data = platforms_data.setdefault(plat.value, {})
|
||||
if not isinstance(plat_data, dict):
|
||||
plat_data = {}
|
||||
platforms_data[plat.value] = plat_data
|
||||
extra = plat_data.setdefault("extra", {})
|
||||
if not isinstance(extra, dict):
|
||||
extra = {}
|
||||
plat_data["extra"] = extra
|
||||
extra.update(bridged)
|
||||
|
||||
# Discord settings → env vars (env vars take precedence)
|
||||
discord_cfg = yaml_cfg.get("discord", {})
|
||||
if isinstance(discord_cfg, dict):
|
||||
if "require_mention" in discord_cfg and not os.getenv("DISCORD_REQUIRE_MENTION"):
|
||||
@@ -360,6 +526,8 @@ def load_gateway_config() -> GatewayConfig:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
config = GatewayConfig.from_dict(gw_data)
|
||||
|
||||
# Override with environment variables
|
||||
_apply_env_overrides(config)
|
||||
|
||||
@@ -385,6 +553,8 @@ def load_gateway_config() -> GatewayConfig:
|
||||
Platform.TELEGRAM: "TELEGRAM_BOT_TOKEN",
|
||||
Platform.DISCORD: "DISCORD_BOT_TOKEN",
|
||||
Platform.SLACK: "SLACK_BOT_TOKEN",
|
||||
Platform.MATTERMOST: "MATTERMOST_TOKEN",
|
||||
Platform.MATRIX: "MATRIX_ACCESS_TOKEN",
|
||||
}
|
||||
for platform, pconfig in config.platforms.items():
|
||||
if not pconfig.enabled:
|
||||
@@ -478,6 +648,53 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
name=os.getenv("SIGNAL_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Mattermost
|
||||
mattermost_token = os.getenv("MATTERMOST_TOKEN")
|
||||
if mattermost_token:
|
||||
mattermost_url = os.getenv("MATTERMOST_URL", "")
|
||||
if not mattermost_url:
|
||||
logger.warning("MATTERMOST_TOKEN set but MATTERMOST_URL is missing")
|
||||
if Platform.MATTERMOST not in config.platforms:
|
||||
config.platforms[Platform.MATTERMOST] = PlatformConfig()
|
||||
config.platforms[Platform.MATTERMOST].enabled = True
|
||||
config.platforms[Platform.MATTERMOST].token = mattermost_token
|
||||
config.platforms[Platform.MATTERMOST].extra["url"] = mattermost_url
|
||||
mattermost_home = os.getenv("MATTERMOST_HOME_CHANNEL")
|
||||
if mattermost_home:
|
||||
config.platforms[Platform.MATTERMOST].home_channel = HomeChannel(
|
||||
platform=Platform.MATTERMOST,
|
||||
chat_id=mattermost_home,
|
||||
name=os.getenv("MATTERMOST_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Matrix
|
||||
matrix_token = os.getenv("MATRIX_ACCESS_TOKEN")
|
||||
matrix_homeserver = os.getenv("MATRIX_HOMESERVER", "")
|
||||
if matrix_token or os.getenv("MATRIX_PASSWORD"):
|
||||
if not matrix_homeserver:
|
||||
logger.warning("MATRIX_ACCESS_TOKEN/MATRIX_PASSWORD set but MATRIX_HOMESERVER is missing")
|
||||
if Platform.MATRIX not in config.platforms:
|
||||
config.platforms[Platform.MATRIX] = PlatformConfig()
|
||||
config.platforms[Platform.MATRIX].enabled = True
|
||||
if matrix_token:
|
||||
config.platforms[Platform.MATRIX].token = matrix_token
|
||||
config.platforms[Platform.MATRIX].extra["homeserver"] = matrix_homeserver
|
||||
matrix_user = os.getenv("MATRIX_USER_ID", "")
|
||||
if matrix_user:
|
||||
config.platforms[Platform.MATRIX].extra["user_id"] = matrix_user
|
||||
matrix_password = os.getenv("MATRIX_PASSWORD", "")
|
||||
if matrix_password:
|
||||
config.platforms[Platform.MATRIX].extra["password"] = matrix_password
|
||||
matrix_e2ee = os.getenv("MATRIX_ENCRYPTION", "").lower() in ("true", "1", "yes")
|
||||
config.platforms[Platform.MATRIX].extra["encryption"] = matrix_e2ee
|
||||
matrix_home = os.getenv("MATRIX_HOME_ROOM")
|
||||
if matrix_home:
|
||||
config.platforms[Platform.MATRIX].home_channel = HomeChannel(
|
||||
platform=Platform.MATRIX,
|
||||
chat_id=matrix_home,
|
||||
name=os.getenv("MATRIX_HOME_ROOM_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Home Assistant
|
||||
hass_token = os.getenv("HASS_TOKEN")
|
||||
if hass_token:
|
||||
@@ -511,6 +728,61 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
name=os.getenv("EMAIL_HOME_ADDRESS_NAME", "Home"),
|
||||
)
|
||||
|
||||
# SMS (Twilio)
|
||||
twilio_sid = os.getenv("TWILIO_ACCOUNT_SID")
|
||||
if twilio_sid:
|
||||
if Platform.SMS not in config.platforms:
|
||||
config.platforms[Platform.SMS] = PlatformConfig()
|
||||
config.platforms[Platform.SMS].enabled = True
|
||||
config.platforms[Platform.SMS].api_key = os.getenv("TWILIO_AUTH_TOKEN", "")
|
||||
sms_home = os.getenv("SMS_HOME_CHANNEL")
|
||||
if sms_home:
|
||||
config.platforms[Platform.SMS].home_channel = HomeChannel(
|
||||
platform=Platform.SMS,
|
||||
chat_id=sms_home,
|
||||
name=os.getenv("SMS_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# API Server
|
||||
api_server_enabled = os.getenv("API_SERVER_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
api_server_key = os.getenv("API_SERVER_KEY", "")
|
||||
api_server_cors_origins = os.getenv("API_SERVER_CORS_ORIGINS", "")
|
||||
api_server_port = os.getenv("API_SERVER_PORT")
|
||||
api_server_host = os.getenv("API_SERVER_HOST")
|
||||
if api_server_enabled or api_server_key:
|
||||
if Platform.API_SERVER not in config.platforms:
|
||||
config.platforms[Platform.API_SERVER] = PlatformConfig()
|
||||
config.platforms[Platform.API_SERVER].enabled = True
|
||||
if api_server_key:
|
||||
config.platforms[Platform.API_SERVER].extra["key"] = api_server_key
|
||||
if api_server_cors_origins:
|
||||
origins = [origin.strip() for origin in api_server_cors_origins.split(",") if origin.strip()]
|
||||
if origins:
|
||||
config.platforms[Platform.API_SERVER].extra["cors_origins"] = origins
|
||||
if api_server_port:
|
||||
try:
|
||||
config.platforms[Platform.API_SERVER].extra["port"] = int(api_server_port)
|
||||
except ValueError:
|
||||
pass
|
||||
if api_server_host:
|
||||
config.platforms[Platform.API_SERVER].extra["host"] = api_server_host
|
||||
|
||||
# Webhook platform
|
||||
webhook_enabled = os.getenv("WEBHOOK_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
webhook_port = os.getenv("WEBHOOK_PORT")
|
||||
webhook_secret = os.getenv("WEBHOOK_SECRET", "")
|
||||
if webhook_enabled:
|
||||
if Platform.WEBHOOK not in config.platforms:
|
||||
config.platforms[Platform.WEBHOOK] = PlatformConfig()
|
||||
config.platforms[Platform.WEBHOOK].enabled = True
|
||||
if webhook_port:
|
||||
try:
|
||||
config.platforms[Platform.WEBHOOK].extra["port"] = int(webhook_port)
|
||||
except ValueError:
|
||||
pass
|
||||
if webhook_secret:
|
||||
config.platforms[Platform.WEBHOOK].extra["secret"] = webhook_secret
|
||||
|
||||
# Session settings
|
||||
idle_minutes = os.getenv("SESSION_IDLE_MINUTES")
|
||||
if idle_minutes:
|
||||
@@ -527,10 +799,3 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def save_gateway_config(config: GatewayConfig) -> None:
|
||||
"""Save gateway configuration to ~/.hermes/gateway.json."""
|
||||
gateway_config_path = get_hermes_home() / "gateway.json"
|
||||
gateway_config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(gateway_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config.to_dict(), f, indent=2)
|
||||
|
||||
+3
-2
@@ -8,8 +8,9 @@ Hooks are discovered from ~/.hermes/hooks/ directories, each containing:
|
||||
|
||||
Events:
|
||||
- gateway:startup -- Gateway process starts
|
||||
- session:start -- New session created
|
||||
- session:reset -- User ran /new or /reset
|
||||
- session:start -- New session created (first message of a new session)
|
||||
- session:end -- Session ends (user ran /new or /reset)
|
||||
- session:reset -- Session reset completed (new session entry created)
|
||||
- agent:start -- Agent begins processing a message
|
||||
- agent:step -- Each turn in the tool-calling loop
|
||||
- agent:end -- Agent finishes processing
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+163
-7
@@ -294,6 +294,7 @@ class MessageEvent:
|
||||
|
||||
# Reply context
|
||||
reply_to_message_id: Optional[str] = None
|
||||
reply_to_text: Optional[str] = None # Text of the replied-to message (for context injection)
|
||||
|
||||
# Timestamps
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
@@ -503,6 +504,14 @@ class BasePlatformAdapter(ABC):
|
||||
metadata: optional dict with platform-specific context (e.g. thread_id for Slack).
|
||||
"""
|
||||
pass
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop a persistent typing indicator (if the platform uses one).
|
||||
|
||||
Override in subclasses that start background typing loops.
|
||||
Default is a no-op for platforms with one-shot typing indicators.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
@@ -510,6 +519,7 @@ class BasePlatformAdapter(ABC):
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an image natively via the platform API.
|
||||
@@ -528,6 +538,7 @@ class BasePlatformAdapter(ABC):
|
||||
animation_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an animated GIF natively via the platform API.
|
||||
@@ -536,7 +547,7 @@ class BasePlatformAdapter(ABC):
|
||||
(e.g., Telegram send_animation) so they auto-play inline.
|
||||
Default falls back to send_image.
|
||||
"""
|
||||
return await self.send_image(chat_id=chat_id, image_url=animation_url, caption=caption, reply_to=reply_to)
|
||||
return await self.send_image(chat_id=chat_id, image_url=animation_url, caption=caption, reply_to=reply_to, metadata=metadata)
|
||||
|
||||
@staticmethod
|
||||
def _is_animation_url(url: str) -> bool:
|
||||
@@ -710,7 +721,7 @@ class BasePlatformAdapter(ABC):
|
||||
# Extract MEDIA:<path> tags, allowing optional whitespace after the colon
|
||||
# and quoted/backticked paths for LLM-formatted outputs.
|
||||
media_pattern = re.compile(
|
||||
r'''[`"']?MEDIA:\s*(?P<path>`[^`\n]+`|"[^"\n]+"|'[^'\n]+'|\S+)[`"']?'''
|
||||
r'''[`"']?MEDIA:\s*(?P<path>`[^`\n]+`|"[^"\n]+"|'[^'\n]+'|(?:~/|/)\S+(?:[^\S\n]+\S+)*?\.(?:png|jpe?g|gif|webp|mp4|mov|avi|mkv|webm|ogg|opus|mp3|wav|m4a)(?=[\s`"',;:)\]}]|$)|\S+)[`"']?'''
|
||||
)
|
||||
for match in media_pattern.finditer(content):
|
||||
path = match.group("path").strip()
|
||||
@@ -726,7 +737,75 @@ class BasePlatformAdapter(ABC):
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
return media, cleaned
|
||||
|
||||
|
||||
@staticmethod
|
||||
def extract_local_files(content: str) -> Tuple[List[str], str]:
|
||||
"""
|
||||
Detect bare local file paths in response text for native media delivery.
|
||||
|
||||
Matches absolute paths (/...) and tilde paths (~/) ending in common
|
||||
image or video extensions. Validates each candidate with
|
||||
``os.path.isfile()`` to avoid false positives from URLs or
|
||||
non-existent paths.
|
||||
|
||||
Paths inside fenced code blocks (``` ... ```) and inline code
|
||||
(`...`) are ignored so that code samples are never mutilated.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of expanded file paths, cleaned text with the
|
||||
raw path strings removed).
|
||||
"""
|
||||
_LOCAL_MEDIA_EXTS = (
|
||||
'.png', '.jpg', '.jpeg', '.gif', '.webp',
|
||||
'.mp4', '.mov', '.avi', '.mkv', '.webm',
|
||||
)
|
||||
ext_part = '|'.join(e.lstrip('.') for e in _LOCAL_MEDIA_EXTS)
|
||||
|
||||
# (?<![/:\w.]) prevents matching inside URLs (e.g. https://…/img.png)
|
||||
# and relative paths (./foo.png)
|
||||
# (?:~/|/) anchors to absolute or home-relative paths
|
||||
path_re = re.compile(
|
||||
r'(?<![/:\w.])(?:~/|/)(?:[\w.\-]+/)*[\w.\-]+\.(?:' + ext_part + r')\b',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Build spans covered by fenced code blocks and inline code
|
||||
code_spans: list = []
|
||||
for m in re.finditer(r'```[^\n]*\n.*?```', content, re.DOTALL):
|
||||
code_spans.append((m.start(), m.end()))
|
||||
for m in re.finditer(r'`[^`\n]+`', content):
|
||||
code_spans.append((m.start(), m.end()))
|
||||
|
||||
def _in_code(pos: int) -> bool:
|
||||
return any(s <= pos < e for s, e in code_spans)
|
||||
|
||||
found: list = [] # (raw_match_text, expanded_path)
|
||||
for match in path_re.finditer(content):
|
||||
if _in_code(match.start()):
|
||||
continue
|
||||
raw = match.group(0)
|
||||
expanded = os.path.expanduser(raw)
|
||||
if os.path.isfile(expanded):
|
||||
found.append((raw, expanded))
|
||||
|
||||
# Deduplicate by expanded path, preserving discovery order
|
||||
seen: set = set()
|
||||
unique: list = []
|
||||
for raw, expanded in found:
|
||||
if expanded not in seen:
|
||||
seen.add(expanded)
|
||||
unique.append((raw, expanded))
|
||||
|
||||
paths = [expanded for _, expanded in unique]
|
||||
|
||||
cleaned = content
|
||||
if unique:
|
||||
for raw, _exp in unique:
|
||||
cleaned = cleaned.replace(raw, '')
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
return paths, cleaned
|
||||
|
||||
async def _keep_typing(self, chat_id: str, interval: float = 2.0, metadata=None) -> None:
|
||||
"""
|
||||
Continuously send typing indicator until cancelled.
|
||||
@@ -752,7 +831,10 @@ class BasePlatformAdapter(ABC):
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
session_key = build_session_key(event.source)
|
||||
session_key = build_session_key(
|
||||
event.source,
|
||||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||||
)
|
||||
|
||||
# Check if there's already an active handler for this session
|
||||
if session_key in self._active_sessions:
|
||||
@@ -836,8 +918,17 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
# Extract image URLs and send them as native platform attachments
|
||||
images, text_content = self.extract_images(response)
|
||||
# Strip any remaining internal directives from message body (fixes #1561)
|
||||
text_content = text_content.replace("[[audio_as_voice]]", "").strip()
|
||||
text_content = re.sub(r"MEDIA:\s*\S+", "", text_content).strip()
|
||||
if images:
|
||||
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
||||
|
||||
# Auto-detect bare local file paths for native media delivery
|
||||
# (helps small models that don't use MEDIA: syntax)
|
||||
local_files, text_content = self.extract_local_files(text_content)
|
||||
if local_files:
|
||||
logger.info("[%s] extract_local_files found %d file(s) in response", self.name, len(local_files))
|
||||
|
||||
# Auto-TTS: if voice message, generate audio FIRST (before sending text)
|
||||
# Skipped when the chat has voice mode disabled (/voice off)
|
||||
@@ -931,7 +1022,7 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
# Send extracted media files — route by file type
|
||||
_AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}
|
||||
_VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.3gp'}
|
||||
_VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.webm', '.3gp'}
|
||||
_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif'}
|
||||
|
||||
for media_path, is_voice in media_files:
|
||||
@@ -968,7 +1059,34 @@ class BasePlatformAdapter(ABC):
|
||||
print(f"[{self.name}] Failed to send media ({ext}): {media_result.error}")
|
||||
except Exception as media_err:
|
||||
print(f"[{self.name}] Error sending media: {media_err}")
|
||||
|
||||
|
||||
# Send auto-detected local files as native attachments
|
||||
for file_path in local_files:
|
||||
if human_delay > 0:
|
||||
await asyncio.sleep(human_delay)
|
||||
try:
|
||||
ext = Path(file_path).suffix.lower()
|
||||
if ext in _IMAGE_EXTS:
|
||||
await self.send_image_file(
|
||||
chat_id=event.source.chat_id,
|
||||
image_path=file_path,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
elif ext in _VIDEO_EXTS:
|
||||
await self.send_video(
|
||||
chat_id=event.source.chat_id,
|
||||
video_path=file_path,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
else:
|
||||
await self.send_document(
|
||||
chat_id=event.source.chat_id,
|
||||
file_path=file_path,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
except Exception as file_err:
|
||||
logger.error("[%s] Error sending local file %s: %s", self.name, file_path, file_err)
|
||||
|
||||
# Check if there's a pending message that was queued during our processing
|
||||
if session_key in self._pending_messages:
|
||||
pending_event = self._pending_messages.pop(session_key)
|
||||
@@ -989,6 +1107,22 @@ class BasePlatformAdapter(ABC):
|
||||
print(f"[{self.name}] Error handling message: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# Send the error to the user so they aren't left with radio silence
|
||||
try:
|
||||
error_type = type(e).__name__
|
||||
error_detail = str(e)[:300] if str(e) else "no details available"
|
||||
_thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None
|
||||
await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=(
|
||||
f"Sorry, I encountered an error ({error_type}).\n"
|
||||
f"{error_detail}\n"
|
||||
"Try again or use /reset to start a fresh session."
|
||||
),
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
except Exception:
|
||||
pass # Last resort — don't let error reporting crash the handler
|
||||
finally:
|
||||
# Stop typing indicator
|
||||
typing_task.cancel()
|
||||
@@ -1074,7 +1208,8 @@ class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
return content
|
||||
|
||||
def truncate_message(self, content: str, max_length: int = 4096) -> List[str]:
|
||||
@staticmethod
|
||||
def truncate_message(content: str, max_length: int = 4096) -> List[str]:
|
||||
"""
|
||||
Split a long message into chunks, preserving code block boundaries.
|
||||
|
||||
@@ -1126,6 +1261,27 @@ class BasePlatformAdapter(ABC):
|
||||
if split_at < 1:
|
||||
split_at = headroom
|
||||
|
||||
# Avoid splitting inside an inline code span (`...`).
|
||||
# If the text before split_at has an odd number of unescaped
|
||||
# backticks, the split falls inside inline code — the resulting
|
||||
# chunk would have an unpaired backtick and any special characters
|
||||
# (like parentheses) inside the broken span would be unescaped,
|
||||
# causing MarkdownV2 parse errors on Telegram.
|
||||
candidate = remaining[:split_at]
|
||||
backtick_count = candidate.count("`") - candidate.count("\\`")
|
||||
if backtick_count % 2 == 1:
|
||||
# Find the last unescaped backtick and split before it
|
||||
last_bt = candidate.rfind("`")
|
||||
while last_bt > 0 and candidate[last_bt - 1] == "\\":
|
||||
last_bt = candidate.rfind("`", 0, last_bt)
|
||||
if last_bt > 0:
|
||||
# Try to find a space or newline just before the backtick
|
||||
safe_split = candidate.rfind(" ", 0, last_bt)
|
||||
nl_split = candidate.rfind("\n", 0, last_bt)
|
||||
safe_split = max(safe_split, nl_split)
|
||||
if safe_split > headroom // 4:
|
||||
split_at = safe_split
|
||||
|
||||
chunk_body = remaining[:split_at]
|
||||
remaining = remaining[split_at:].lstrip()
|
||||
|
||||
|
||||
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
DingTalk platform adapter using Stream Mode.
|
||||
|
||||
Uses dingtalk-stream SDK for real-time message reception without webhooks.
|
||||
Responses are sent via DingTalk's session webhook (markdown format).
|
||||
|
||||
Requires:
|
||||
pip install dingtalk-stream httpx
|
||||
DINGTALK_CLIENT_ID and DINGTALK_CLIENT_SECRET env vars
|
||||
|
||||
Configuration in config.yaml:
|
||||
platforms:
|
||||
dingtalk:
|
||||
enabled: true
|
||||
extra:
|
||||
client_id: "your-app-key" # or DINGTALK_CLIENT_ID env var
|
||||
client_secret: "your-secret" # or DINGTALK_CLIENT_SECRET env var
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
try:
|
||||
import dingtalk_stream
|
||||
from dingtalk_stream import ChatbotHandler, ChatbotMessage
|
||||
DINGTALK_STREAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
DINGTALK_STREAM_AVAILABLE = False
|
||||
dingtalk_stream = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import httpx
|
||||
HTTPX_AVAILABLE = True
|
||||
except ImportError:
|
||||
HTTPX_AVAILABLE = False
|
||||
httpx = None # type: ignore[assignment]
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_MESSAGE_LENGTH = 20000
|
||||
DEDUP_WINDOW_SECONDS = 300
|
||||
DEDUP_MAX_SIZE = 1000
|
||||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||||
|
||||
|
||||
def check_dingtalk_requirements() -> bool:
|
||||
"""Check if DingTalk dependencies are available and configured."""
|
||||
if not DINGTALK_STREAM_AVAILABLE or not HTTPX_AVAILABLE:
|
||||
return False
|
||||
if not os.getenv("DINGTALK_CLIENT_ID") or not os.getenv("DINGTALK_CLIENT_SECRET"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class DingTalkAdapter(BasePlatformAdapter):
|
||||
"""DingTalk chatbot adapter using Stream Mode.
|
||||
|
||||
The dingtalk-stream SDK maintains a long-lived WebSocket connection.
|
||||
Incoming messages arrive via a ChatbotHandler callback. Replies are
|
||||
sent via the incoming message's session_webhook URL using httpx.
|
||||
"""
|
||||
|
||||
MAX_MESSAGE_LENGTH = MAX_MESSAGE_LENGTH
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.DINGTALK)
|
||||
|
||||
extra = config.extra or {}
|
||||
self._client_id: str = extra.get("client_id") or os.getenv("DINGTALK_CLIENT_ID", "")
|
||||
self._client_secret: str = extra.get("client_secret") or os.getenv("DINGTALK_CLIENT_SECRET", "")
|
||||
|
||||
self._stream_client: Any = None
|
||||
self._stream_task: Optional[asyncio.Task] = None
|
||||
self._http_client: Optional["httpx.AsyncClient"] = None
|
||||
|
||||
# Message deduplication: msg_id -> timestamp
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
# Map chat_id -> session_webhook for reply routing
|
||||
self._session_webhooks: Dict[str, str] = {}
|
||||
|
||||
# -- Connection lifecycle -----------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to DingTalk via Stream Mode."""
|
||||
if not DINGTALK_STREAM_AVAILABLE:
|
||||
logger.warning("[%s] dingtalk-stream not installed. Run: pip install dingtalk-stream", self.name)
|
||||
return False
|
||||
if not HTTPX_AVAILABLE:
|
||||
logger.warning("[%s] httpx not installed. Run: pip install httpx", self.name)
|
||||
return False
|
||||
if not self._client_id or not self._client_secret:
|
||||
logger.warning("[%s] DINGTALK_CLIENT_ID and DINGTALK_CLIENT_SECRET required", self.name)
|
||||
return False
|
||||
|
||||
try:
|
||||
self._http_client = httpx.AsyncClient(timeout=30.0)
|
||||
|
||||
credential = dingtalk_stream.Credential(self._client_id, self._client_secret)
|
||||
self._stream_client = dingtalk_stream.DingTalkStreamClient(credential)
|
||||
|
||||
# Capture the current event loop for cross-thread dispatch
|
||||
loop = asyncio.get_running_loop()
|
||||
handler = _IncomingHandler(self, loop)
|
||||
self._stream_client.register_callback_handler(
|
||||
dingtalk_stream.ChatbotMessage.TOPIC, handler
|
||||
)
|
||||
|
||||
self._stream_task = asyncio.create_task(self._run_stream())
|
||||
self._mark_connected()
|
||||
logger.info("[%s] Connected via Stream Mode", self.name)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("[%s] Failed to connect: %s", self.name, e)
|
||||
return False
|
||||
|
||||
async def _run_stream(self) -> None:
|
||||
"""Run the blocking stream client with auto-reconnection."""
|
||||
backoff_idx = 0
|
||||
while self._running:
|
||||
try:
|
||||
logger.debug("[%s] Starting stream client...", self.name)
|
||||
await asyncio.to_thread(self._stream_client.start)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
if not self._running:
|
||||
return
|
||||
logger.warning("[%s] Stream client error: %s", self.name, e)
|
||||
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
delay = RECONNECT_BACKOFF[min(backoff_idx, len(RECONNECT_BACKOFF) - 1)]
|
||||
logger.info("[%s] Reconnecting in %ds...", self.name, delay)
|
||||
await asyncio.sleep(delay)
|
||||
backoff_idx += 1
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from DingTalk."""
|
||||
self._running = False
|
||||
self._mark_disconnected()
|
||||
|
||||
if self._stream_task:
|
||||
self._stream_task.cancel()
|
||||
try:
|
||||
await self._stream_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._stream_task = None
|
||||
|
||||
if self._http_client:
|
||||
await self._http_client.aclose()
|
||||
self._http_client = None
|
||||
|
||||
self._stream_client = None
|
||||
self._session_webhooks.clear()
|
||||
self._seen_messages.clear()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
# -- Inbound message processing -----------------------------------------
|
||||
|
||||
async def _on_message(self, message: "ChatbotMessage") -> None:
|
||||
"""Process an incoming DingTalk chatbot message."""
|
||||
msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex
|
||||
if self._is_duplicate(msg_id):
|
||||
logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id)
|
||||
return
|
||||
|
||||
text = self._extract_text(message)
|
||||
if not text:
|
||||
logger.debug("[%s] Empty message, skipping", self.name)
|
||||
return
|
||||
|
||||
# Chat context
|
||||
conversation_id = getattr(message, "conversation_id", "") or ""
|
||||
conversation_type = getattr(message, "conversation_type", "1")
|
||||
is_group = str(conversation_type) == "2"
|
||||
sender_id = getattr(message, "sender_id", "") or ""
|
||||
sender_nick = getattr(message, "sender_nick", "") or sender_id
|
||||
sender_staff_id = getattr(message, "sender_staff_id", "") or ""
|
||||
|
||||
chat_id = conversation_id or sender_id
|
||||
chat_type = "group" if is_group else "dm"
|
||||
|
||||
# Store session webhook for reply routing
|
||||
session_webhook = getattr(message, "session_webhook", None) or ""
|
||||
if session_webhook and chat_id:
|
||||
self._session_webhooks[chat_id] = session_webhook
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=chat_id,
|
||||
chat_name=getattr(message, "conversation_title", None),
|
||||
chat_type=chat_type,
|
||||
user_id=sender_id,
|
||||
user_name=sender_nick,
|
||||
user_id_alt=sender_staff_id if sender_staff_id else None,
|
||||
)
|
||||
|
||||
# Parse timestamp
|
||||
create_at = getattr(message, "create_at", None)
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(int(create_at) / 1000, tz=timezone.utc) if create_at else datetime.now(tz=timezone.utc)
|
||||
except (ValueError, OSError, TypeError):
|
||||
timestamp = datetime.now(tz=timezone.utc)
|
||||
|
||||
event = MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id=msg_id,
|
||||
raw_message=message,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
logger.debug("[%s] Message from %s in %s: %s",
|
||||
self.name, sender_nick, chat_id[:20] if chat_id else "?", text[:50])
|
||||
await self.handle_message(event)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(message: "ChatbotMessage") -> str:
|
||||
"""Extract plain text from a DingTalk chatbot message."""
|
||||
text = getattr(message, "text", None) or ""
|
||||
if isinstance(text, dict):
|
||||
content = text.get("content", "").strip()
|
||||
else:
|
||||
content = str(text).strip()
|
||||
|
||||
# Fall back to rich text if present
|
||||
if not content:
|
||||
rich_text = getattr(message, "rich_text", None)
|
||||
if rich_text and isinstance(rich_text, list):
|
||||
parts = [item["text"] for item in rich_text
|
||||
if isinstance(item, dict) and item.get("text")]
|
||||
content = " ".join(parts).strip()
|
||||
return content
|
||||
|
||||
# -- Deduplication ------------------------------------------------------
|
||||
|
||||
def _is_duplicate(self, msg_id: str) -> bool:
|
||||
"""Check and record a message ID. Returns True if already seen."""
|
||||
now = time.time()
|
||||
if len(self._seen_messages) > DEDUP_MAX_SIZE:
|
||||
cutoff = now - DEDUP_WINDOW_SECONDS
|
||||
self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff}
|
||||
|
||||
if msg_id in self._seen_messages:
|
||||
return True
|
||||
self._seen_messages[msg_id] = now
|
||||
return False
|
||||
|
||||
# -- Outbound messaging -------------------------------------------------
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send a markdown reply via DingTalk session webhook."""
|
||||
metadata = metadata or {}
|
||||
|
||||
session_webhook = metadata.get("session_webhook") or self._session_webhooks.get(chat_id)
|
||||
if not session_webhook:
|
||||
return SendResult(success=False,
|
||||
error="No session_webhook available. Reply must follow an incoming message.")
|
||||
|
||||
if not self._http_client:
|
||||
return SendResult(success=False, error="HTTP client not initialized")
|
||||
|
||||
payload = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {"title": "Hermes", "text": content[:self.MAX_MESSAGE_LENGTH]},
|
||||
}
|
||||
|
||||
try:
|
||||
resp = await self._http_client.post(session_webhook, json=payload, timeout=15.0)
|
||||
if resp.status_code < 300:
|
||||
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
|
||||
body = resp.text
|
||||
logger.warning("[%s] Send failed HTTP %d: %s", self.name, resp.status_code, body[:200])
|
||||
return SendResult(success=False, error=f"HTTP {resp.status_code}: {body[:200]}")
|
||||
except httpx.TimeoutException:
|
||||
return SendResult(success=False, error="Timeout sending message to DingTalk")
|
||||
except Exception as e:
|
||||
logger.error("[%s] Send error: %s", self.name, e)
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""DingTalk does not support typing indicators."""
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Return basic info about a DingTalk conversation."""
|
||||
return {"name": chat_id, "type": "group" if "group" in chat_id.lower() else "dm"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal stream handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _IncomingHandler(ChatbotHandler if DINGTALK_STREAM_AVAILABLE else object):
|
||||
"""dingtalk-stream ChatbotHandler that forwards messages to the adapter."""
|
||||
|
||||
def __init__(self, adapter: DingTalkAdapter, loop: asyncio.AbstractEventLoop):
|
||||
if DINGTALK_STREAM_AVAILABLE:
|
||||
super().__init__()
|
||||
self._adapter = adapter
|
||||
self._loop = loop
|
||||
|
||||
def process(self, message: "ChatbotMessage"):
|
||||
"""Called by dingtalk-stream in its thread when a message arrives.
|
||||
|
||||
Schedules the async handler on the main event loop.
|
||||
"""
|
||||
loop = self._loop
|
||||
if loop is None or loop.is_closed():
|
||||
logger.error("[DingTalk] Event loop unavailable, cannot dispatch message")
|
||||
return dingtalk_stream.AckMessage.STATUS_OK, "OK"
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(self._adapter._on_message(message), loop)
|
||||
try:
|
||||
future.result(timeout=60)
|
||||
except Exception:
|
||||
logger.exception("[DingTalk] Error processing incoming message")
|
||||
|
||||
return dingtalk_stream.AckMessage.STATUS_OK, "OK"
|
||||
+200
-44
@@ -10,6 +10,7 @@ Uses discord.py library for:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
@@ -18,6 +19,7 @@ import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -41,6 +43,8 @@ from pathlib import Path as _Path
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
import re
|
||||
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
@@ -48,6 +52,8 @@ from gateway.platforms.base import (
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
@@ -434,8 +440,14 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._voice_input_callback: Optional[Callable] = None # set by run.py
|
||||
self._on_voice_disconnect: Optional[Callable] = None # set by run.py
|
||||
# Track threads where the bot has participated so follow-up messages
|
||||
# in those threads don't require @mention.
|
||||
self._bot_participated_threads: set = set()
|
||||
# in those threads don't require @mention. Persisted to disk so the
|
||||
# set survives gateway restarts.
|
||||
self._bot_participated_threads: set = self._load_participated_threads()
|
||||
# Persistent typing indicator loops per channel (DMs don't reliably
|
||||
# show the standard typing gateway event for bots)
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
# Cap to prevent unbounded growth (Discord threads get archived).
|
||||
self._MAX_TRACKED_THREADS = 500
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
@@ -519,6 +531,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
|
||||
# Ignore Discord system messages (thread renames, pins, member joins, etc.)
|
||||
# Allow both default and reply types — replies have a distinct MessageType.
|
||||
if message.type not in (discord.MessageType.default, discord.MessageType.reply):
|
||||
return
|
||||
|
||||
# Bot message filtering (DISCORD_ALLOW_BOTS):
|
||||
# "none" — ignore all other bots (default)
|
||||
# "mentions" — accept bot messages only when they @mention us
|
||||
@@ -1234,14 +1251,48 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
return await super().send_document(chat_id, file_path, caption, file_name, reply_to, metadata=metadata)
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._client:
|
||||
"""Start a persistent typing indicator for a channel.
|
||||
|
||||
Discord's TYPING_START gateway event is unreliable in DMs for bots.
|
||||
Instead, start a background loop that hits the typing endpoint every
|
||||
8 seconds (typing indicator lasts ~10s). The loop is cancelled when
|
||||
stop_typing() is called (after the response is sent).
|
||||
"""
|
||||
if not self._client:
|
||||
return
|
||||
# Don't start a duplicate loop
|
||||
if chat_id in self._typing_tasks:
|
||||
return
|
||||
|
||||
async def _typing_loop() -> None:
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if channel:
|
||||
await channel.typing()
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
while True:
|
||||
try:
|
||||
route = discord.http.Route(
|
||||
"POST", "/channels/{channel_id}/typing",
|
||||
channel_id=chat_id,
|
||||
)
|
||||
await self._client.http.request(route)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for %s: %s", chat_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(_typing_loop())
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop the persistent typing indicator for a channel."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a Discord channel."""
|
||||
@@ -1359,16 +1410,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
command_text: str,
|
||||
followup_msg: str = "Done~",
|
||||
followup_msg: str | None = None,
|
||||
) -> None:
|
||||
"""Common handler for simple slash commands that dispatch a command string."""
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, command_text)
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send(followup_msg, ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
if followup_msg:
|
||||
try:
|
||||
await interaction.followup.send(followup_msg, ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
def _register_slash_commands(self) -> None:
|
||||
"""Register Discord slash commands on the command tree."""
|
||||
@@ -1377,19 +1429,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
tree = self._client.tree
|
||||
|
||||
@tree.command(name="ask", description="Ask Hermes a question")
|
||||
@discord.app_commands.describe(question="Your question for Hermes")
|
||||
async def slash_ask(interaction: discord.Interaction, question: str):
|
||||
await interaction.response.defer()
|
||||
event = self._build_slash_event(interaction, question)
|
||||
await self.handle_message(event)
|
||||
# The response is sent via the normal send() flow
|
||||
# Send a followup to close the interaction if needed
|
||||
try:
|
||||
await interaction.followup.send("Processing complete~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="new", description="Start a new conversation")
|
||||
async def slash_new(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/reset", "New conversation started~")
|
||||
@@ -1409,10 +1448,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/reasoning {effort}".strip())
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="personality", description="Set a personality")
|
||||
@discord.app_commands.describe(name="Personality name. Leave empty to list available.")
|
||||
@@ -1488,10 +1523,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/voice {mode}".strip())
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="update", description="Update Hermes Agent to the latest version")
|
||||
async def slash_update(interaction: discord.Interaction):
|
||||
@@ -1515,7 +1546,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
def _build_slash_event(self, interaction: discord.Interaction, text: str) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Discord slash command interaction."""
|
||||
is_dm = isinstance(interaction.channel, discord.DMChannel)
|
||||
chat_type = "dm" if is_dm else "group"
|
||||
is_thread = isinstance(interaction.channel, discord.Thread)
|
||||
thread_id = None
|
||||
|
||||
if is_dm:
|
||||
chat_type = "dm"
|
||||
elif is_thread:
|
||||
chat_type = "thread"
|
||||
thread_id = str(interaction.channel_id)
|
||||
else:
|
||||
chat_type = "group"
|
||||
|
||||
chat_name = ""
|
||||
if not is_dm and hasattr(interaction.channel, "name"):
|
||||
chat_name = interaction.channel.name
|
||||
@@ -1531,6 +1572,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
chat_type=chat_type,
|
||||
user_id=str(interaction.user.id),
|
||||
user_name=interaction.user.display_name,
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
@@ -1573,6 +1615,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
link = f"<#{thread_id}>" if thread_id else f"**{thread_name}**"
|
||||
await interaction.followup.send(f"Created thread {link}", ephemeral=True)
|
||||
|
||||
# Track thread participation so follow-ups don't require @mention
|
||||
if thread_id:
|
||||
self._track_thread(thread_id)
|
||||
|
||||
# If a message was provided, kick off a new Hermes session in the thread
|
||||
starter = (message or "").strip()
|
||||
if starter and thread_id:
|
||||
@@ -1740,9 +1786,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
# Discord embed description limit is 4096; show full command up to that
|
||||
max_desc = 4088
|
||||
cmd_display = command if len(command) <= max_desc else command[: max_desc - 3] + "..."
|
||||
embed = discord.Embed(
|
||||
title="Command Approval Required",
|
||||
description=f"```\n{command[:500]}\n```",
|
||||
description=f"```\n{cmd_display}\n```",
|
||||
color=discord.Color.orange(),
|
||||
)
|
||||
embed.set_footer(text=f"Approval ID: {approval_id}")
|
||||
@@ -1798,6 +1847,49 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
return f"{parent_name} / {thread_name}"
|
||||
return thread_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Thread participation persistence
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _thread_state_path() -> Path:
|
||||
"""Path to the persisted thread participation set."""
|
||||
from hermes_cli.config import get_hermes_home
|
||||
return get_hermes_home() / "discord_threads.json"
|
||||
|
||||
@classmethod
|
||||
def _load_participated_threads(cls) -> set:
|
||||
"""Load persisted thread IDs from disk."""
|
||||
path = cls._thread_state_path()
|
||||
try:
|
||||
if path.exists():
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return set(data)
|
||||
except Exception as e:
|
||||
logger.debug("Could not load discord thread state: %s", e)
|
||||
return set()
|
||||
|
||||
def _save_participated_threads(self) -> None:
|
||||
"""Persist the current thread set to disk (best-effort)."""
|
||||
path = self._thread_state_path()
|
||||
try:
|
||||
# Trim to most recent entries if over cap
|
||||
thread_list = list(self._bot_participated_threads)
|
||||
if len(thread_list) > self._MAX_TRACKED_THREADS:
|
||||
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
|
||||
self._bot_participated_threads = set(thread_list)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(thread_list), encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.debug("Could not save discord thread state: %s", e)
|
||||
|
||||
def _track_thread(self, thread_id: str) -> None:
|
||||
"""Add a thread to the participation set and persist."""
|
||||
if thread_id not in self._bot_participated_threads:
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
self._save_participated_threads()
|
||||
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
@@ -1850,7 +1942,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
is_thread = True
|
||||
thread_id = str(thread.id)
|
||||
auto_threaded_channel = thread
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
self._track_thread(thread_id)
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
@@ -1867,7 +1959,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
elif att.content_type.startswith("audio/"):
|
||||
msg_type = MessageType.AUDIO
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
doc_ext = ""
|
||||
if att.filename:
|
||||
_, doc_ext = os.path.splitext(att.filename)
|
||||
doc_ext = doc_ext.lower()
|
||||
if doc_ext in SUPPORTED_DOCUMENT_TYPES:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
break
|
||||
|
||||
# When auto-threading kicked in, route responses to the new thread
|
||||
@@ -1904,6 +2001,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# vision tool can access them reliably (Discord CDN URLs can expire).
|
||||
media_urls = []
|
||||
media_types = []
|
||||
pending_text_injection: Optional[str] = None
|
||||
for att in message.attachments:
|
||||
content_type = att.content_type or "unknown"
|
||||
if content_type.startswith("image/"):
|
||||
@@ -1935,12 +2033,70 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
else:
|
||||
# Other attachments: keep the original URL
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
# Document attachments: download, cache, and optionally inject text
|
||||
ext = ""
|
||||
if att.filename:
|
||||
_, ext = os.path.splitext(att.filename)
|
||||
ext = ext.lower()
|
||||
if not ext and content_type:
|
||||
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
|
||||
ext = mime_to_ext.get(content_type, "")
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
logger.warning(
|
||||
"[Discord] Unsupported document type '%s' (%s), skipping",
|
||||
ext or "unknown", content_type,
|
||||
)
|
||||
else:
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if att.size and att.size > MAX_DOC_BYTES:
|
||||
logger.warning(
|
||||
"[Discord] Document too large (%s bytes), skipping: %s",
|
||||
att.size, att.filename,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
raw_bytes = await resp.read()
|
||||
cached_path = cache_document_from_bytes(
|
||||
raw_bytes, att.filename or f"document{ext}"
|
||||
)
|
||||
doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(doc_mime)
|
||||
logger.info("[Discord] Cached user document: %s", cached_path)
|
||||
# Inject text content for .txt/.md files (capped at 100 KB)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = att.filename or f"document{ext}"
|
||||
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if pending_text_injection:
|
||||
pending_text_injection = f"{pending_text_injection}\n\n{injection}"
|
||||
else:
|
||||
pending_text_injection = injection
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"[Discord] Failed to cache document %s: %s",
|
||||
att.filename, e, exc_info=True,
|
||||
)
|
||||
|
||||
event_text = message.content
|
||||
if pending_text_injection:
|
||||
event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection
|
||||
|
||||
event = MessageEvent(
|
||||
text=message.content,
|
||||
text=event_text,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=message,
|
||||
@@ -1954,7 +2110,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# Track thread participation so the bot won't require @mention for
|
||||
# follow-up messages in threads it has already engaged in.
|
||||
if thread_id:
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
self._track_thread(thread_id)
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
@@ -135,14 +135,23 @@ def _extract_email_address(raw: str) -> str:
|
||||
return raw.strip().lower()
|
||||
|
||||
|
||||
def _extract_attachments(msg: email_lib.message.Message) -> List[Dict[str, Any]]:
|
||||
"""Extract attachment metadata and cache files locally."""
|
||||
def _extract_attachments(
|
||||
msg: email_lib.message.Message,
|
||||
skip_attachments: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Extract attachment metadata and cache files locally.
|
||||
|
||||
When *skip_attachments* is True, all attachment/inline parts are ignored
|
||||
(useful for malware protection or bandwidth savings).
|
||||
"""
|
||||
attachments = []
|
||||
if not msg.is_multipart():
|
||||
return attachments
|
||||
|
||||
for part in msg.walk():
|
||||
disposition = str(part.get("Content-Disposition", ""))
|
||||
if skip_attachments and ("attachment" in disposition or "inline" in disposition):
|
||||
continue
|
||||
if "attachment" not in disposition and "inline" not in disposition:
|
||||
continue
|
||||
# Skip text/plain and text/html body parts
|
||||
@@ -196,6 +205,13 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
self._smtp_port = int(os.getenv("EMAIL_SMTP_PORT", "587"))
|
||||
self._poll_interval = int(os.getenv("EMAIL_POLL_INTERVAL", "15"))
|
||||
|
||||
# Skip attachments — configured via config.yaml:
|
||||
# platforms:
|
||||
# email:
|
||||
# skip_attachments: true
|
||||
extra = config.extra or {}
|
||||
self._skip_attachments = extra.get("skip_attachments", False)
|
||||
|
||||
# Track message IDs we've already processed to avoid duplicates
|
||||
self._seen_uids: set = set()
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
@@ -214,7 +230,7 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
# Mark all existing messages as seen so we only process new ones
|
||||
imap.select("INBOX")
|
||||
status, data = imap.uid("search", None, "ALL")
|
||||
if status == "OK" and data[0]:
|
||||
if status == "OK" and data and data[0]:
|
||||
for uid in data[0].split():
|
||||
self._seen_uids.add(uid)
|
||||
imap.logout()
|
||||
@@ -279,7 +295,7 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
imap.select("INBOX")
|
||||
|
||||
status, data = imap.uid("search", None, "UNSEEN")
|
||||
if status != "OK" or not data[0]:
|
||||
if status != "OK" or not data or not data[0]:
|
||||
imap.logout()
|
||||
return results
|
||||
|
||||
@@ -306,7 +322,7 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
message_id = msg.get("Message-ID", "")
|
||||
in_reply_to = msg.get("In-Reply-To", "")
|
||||
body = _extract_text_body(msg)
|
||||
attachments = _extract_attachments(msg)
|
||||
attachments = _extract_attachments(msg, skip_attachments=self._skip_attachments)
|
||||
|
||||
results.append({
|
||||
"uid": uid,
|
||||
@@ -436,7 +452,7 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
logger.info("[Email] Sent reply to %s (subject: %s)", to_addr, subject)
|
||||
return msg_id
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
async def send_typing(self, chat_id: str, metadata: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Email has no typing indicator — no-op."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -0,0 +1,895 @@
|
||||
"""Matrix gateway adapter.
|
||||
|
||||
Connects to any Matrix homeserver (self-hosted or matrix.org) via the
|
||||
matrix-nio Python SDK. Supports optional end-to-end encryption (E2EE)
|
||||
when installed with ``pip install "matrix-nio[e2e]"``.
|
||||
|
||||
Environment variables:
|
||||
MATRIX_HOMESERVER Homeserver URL (e.g. https://matrix.example.org)
|
||||
MATRIX_ACCESS_TOKEN Access token (preferred auth method)
|
||||
MATRIX_USER_ID Full user ID (@bot:server) — required for password login
|
||||
MATRIX_PASSWORD Password (alternative to access token)
|
||||
MATRIX_ENCRYPTION Set "true" to enable E2EE
|
||||
MATRIX_ALLOWED_USERS Comma-separated Matrix user IDs (@user:server)
|
||||
MATRIX_HOME_ROOM Room ID for cron/notification delivery
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Matrix message size limit (4000 chars practical, spec has no hard limit
|
||||
# but clients render poorly above this).
|
||||
MAX_MESSAGE_LENGTH = 4000
|
||||
|
||||
# Store directory for E2EE keys and sync state.
|
||||
_STORE_DIR = Path.home() / ".hermes" / "matrix" / "store"
|
||||
|
||||
# Grace period: ignore messages older than this many seconds before startup.
|
||||
_STARTUP_GRACE_SECONDS = 5
|
||||
|
||||
|
||||
def check_matrix_requirements() -> bool:
|
||||
"""Return True if the Matrix adapter can be used."""
|
||||
token = os.getenv("MATRIX_ACCESS_TOKEN", "")
|
||||
password = os.getenv("MATRIX_PASSWORD", "")
|
||||
homeserver = os.getenv("MATRIX_HOMESERVER", "")
|
||||
|
||||
if not token and not password:
|
||||
logger.debug("Matrix: neither MATRIX_ACCESS_TOKEN nor MATRIX_PASSWORD set")
|
||||
return False
|
||||
if not homeserver:
|
||||
logger.warning("Matrix: MATRIX_HOMESERVER not set")
|
||||
return False
|
||||
try:
|
||||
import nio # noqa: F401
|
||||
return True
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"Matrix: matrix-nio not installed. "
|
||||
"Run: pip install 'matrix-nio[e2e]'"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class MatrixAdapter(BasePlatformAdapter):
|
||||
"""Gateway adapter for Matrix (any homeserver)."""
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.MATRIX)
|
||||
|
||||
self._homeserver: str = (
|
||||
config.extra.get("homeserver", "")
|
||||
or os.getenv("MATRIX_HOMESERVER", "")
|
||||
).rstrip("/")
|
||||
self._access_token: str = config.token or os.getenv("MATRIX_ACCESS_TOKEN", "")
|
||||
self._user_id: str = (
|
||||
config.extra.get("user_id", "")
|
||||
or os.getenv("MATRIX_USER_ID", "")
|
||||
)
|
||||
self._password: str = (
|
||||
config.extra.get("password", "")
|
||||
or os.getenv("MATRIX_PASSWORD", "")
|
||||
)
|
||||
self._encryption: bool = config.extra.get(
|
||||
"encryption",
|
||||
os.getenv("MATRIX_ENCRYPTION", "").lower() in ("true", "1", "yes"),
|
||||
)
|
||||
|
||||
self._client: Any = None # nio.AsyncClient
|
||||
self._sync_task: Optional[asyncio.Task] = None
|
||||
self._closing = False
|
||||
self._startup_ts: float = 0.0
|
||||
|
||||
# Cache: room_id → bool (is DM)
|
||||
self._dm_rooms: Dict[str, bool] = {}
|
||||
# Set of room IDs we've joined
|
||||
self._joined_rooms: Set[str] = set()
|
||||
# Event deduplication (bounded deque keeps newest entries)
|
||||
from collections import deque
|
||||
self._processed_events: deque = deque(maxlen=1000)
|
||||
self._processed_events_set: set = set()
|
||||
|
||||
def _is_duplicate_event(self, event_id) -> bool:
|
||||
"""Return True if this event was already processed. Tracks the ID otherwise."""
|
||||
if not event_id:
|
||||
return False
|
||||
if event_id in self._processed_events_set:
|
||||
return True
|
||||
if len(self._processed_events) == self._processed_events.maxlen:
|
||||
evicted = self._processed_events[0]
|
||||
self._processed_events_set.discard(evicted)
|
||||
self._processed_events.append(event_id)
|
||||
self._processed_events_set.add(event_id)
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Required overrides
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to the Matrix homeserver and start syncing."""
|
||||
import nio
|
||||
|
||||
if not self._homeserver:
|
||||
logger.error("Matrix: homeserver URL not configured")
|
||||
return False
|
||||
|
||||
# Determine store path and ensure it exists.
|
||||
store_path = str(_STORE_DIR)
|
||||
_STORE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create the client.
|
||||
if self._encryption:
|
||||
try:
|
||||
client = nio.AsyncClient(
|
||||
self._homeserver,
|
||||
self._user_id or "",
|
||||
store_path=store_path,
|
||||
)
|
||||
logger.info("Matrix: E2EE enabled (store: %s)", store_path)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Matrix: failed to create E2EE client (%s), "
|
||||
"falling back to plain client. Install: "
|
||||
"pip install 'matrix-nio[e2e]'",
|
||||
exc,
|
||||
)
|
||||
client = nio.AsyncClient(self._homeserver, self._user_id or "")
|
||||
else:
|
||||
client = nio.AsyncClient(self._homeserver, self._user_id or "")
|
||||
|
||||
self._client = client
|
||||
|
||||
# Authenticate.
|
||||
if self._access_token:
|
||||
client.access_token = self._access_token
|
||||
# Resolve user_id if not set.
|
||||
if not self._user_id:
|
||||
resp = await client.whoami()
|
||||
if isinstance(resp, nio.WhoamiResponse):
|
||||
self._user_id = resp.user_id
|
||||
client.user_id = resp.user_id
|
||||
logger.info("Matrix: authenticated as %s", self._user_id)
|
||||
else:
|
||||
logger.error(
|
||||
"Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER"
|
||||
)
|
||||
await client.close()
|
||||
return False
|
||||
else:
|
||||
client.user_id = self._user_id
|
||||
logger.info("Matrix: using access token for %s", self._user_id)
|
||||
elif self._password and self._user_id:
|
||||
resp = await client.login(
|
||||
self._password,
|
||||
device_name="Hermes Agent",
|
||||
)
|
||||
if isinstance(resp, nio.LoginResponse):
|
||||
logger.info("Matrix: logged in as %s", self._user_id)
|
||||
else:
|
||||
logger.error("Matrix: login failed — %s", getattr(resp, "message", resp))
|
||||
await client.close()
|
||||
return False
|
||||
else:
|
||||
logger.error("Matrix: need MATRIX_ACCESS_TOKEN or MATRIX_USER_ID + MATRIX_PASSWORD")
|
||||
await client.close()
|
||||
return False
|
||||
|
||||
# If E2EE is enabled, load the crypto store.
|
||||
if self._encryption and hasattr(client, "olm"):
|
||||
try:
|
||||
if client.should_upload_keys:
|
||||
await client.keys_upload()
|
||||
logger.info("Matrix: E2EE crypto initialized")
|
||||
except Exception as exc:
|
||||
logger.warning("Matrix: crypto init issue: %s", exc)
|
||||
|
||||
# Register event callbacks.
|
||||
client.add_event_callback(self._on_room_message, nio.RoomMessageText)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageImage)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageAudio)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageVideo)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageFile)
|
||||
client.add_event_callback(self._on_invite, nio.InviteMemberEvent)
|
||||
|
||||
# If E2EE: handle encrypted events.
|
||||
if self._encryption and hasattr(client, "olm"):
|
||||
client.add_event_callback(
|
||||
self._on_room_message, nio.MegolmEvent
|
||||
)
|
||||
|
||||
# Initial sync to catch up, then start background sync.
|
||||
self._startup_ts = time.time()
|
||||
self._closing = False
|
||||
|
||||
# Do an initial sync to populate room state.
|
||||
resp = await client.sync(timeout=10000, full_state=True)
|
||||
if isinstance(resp, nio.SyncResponse):
|
||||
self._joined_rooms = set(resp.rooms.join.keys())
|
||||
logger.info(
|
||||
"Matrix: initial sync complete, joined %d rooms",
|
||||
len(self._joined_rooms),
|
||||
)
|
||||
# Build DM room cache from m.direct account data.
|
||||
await self._refresh_dm_cache()
|
||||
else:
|
||||
logger.warning("Matrix: initial sync returned %s", type(resp).__name__)
|
||||
|
||||
# Start the sync loop.
|
||||
self._sync_task = asyncio.create_task(self._sync_loop())
|
||||
self._mark_connected()
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Matrix."""
|
||||
self._closing = True
|
||||
|
||||
if self._sync_task and not self._sync_task.done():
|
||||
self._sync_task.cancel()
|
||||
try:
|
||||
await self._sync_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
|
||||
logger.info("Matrix: disconnected")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send a message to a Matrix room."""
|
||||
import nio
|
||||
|
||||
if not content:
|
||||
return SendResult(success=True)
|
||||
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, MAX_MESSAGE_LENGTH)
|
||||
|
||||
last_event_id = None
|
||||
for chunk in chunks:
|
||||
msg_content: Dict[str, Any] = {
|
||||
"msgtype": "m.text",
|
||||
"body": chunk,
|
||||
}
|
||||
|
||||
# Convert markdown to HTML for rich rendering.
|
||||
html = self._markdown_to_html(chunk)
|
||||
if html and html != chunk:
|
||||
msg_content["format"] = "org.matrix.custom.html"
|
||||
msg_content["formatted_body"] = html
|
||||
|
||||
# Reply-to support.
|
||||
if reply_to:
|
||||
msg_content["m.relates_to"] = {
|
||||
"m.in_reply_to": {"event_id": reply_to}
|
||||
}
|
||||
|
||||
# Thread support: if metadata has thread_id, send as threaded reply.
|
||||
thread_id = (metadata or {}).get("thread_id")
|
||||
if thread_id:
|
||||
relates_to = msg_content.get("m.relates_to", {})
|
||||
relates_to["rel_type"] = "m.thread"
|
||||
relates_to["event_id"] = thread_id
|
||||
relates_to["is_falling_back"] = True
|
||||
if reply_to and "m.in_reply_to" not in relates_to:
|
||||
relates_to["m.in_reply_to"] = {"event_id": reply_to}
|
||||
msg_content["m.relates_to"] = relates_to
|
||||
|
||||
resp = await self._client.room_send(
|
||||
chat_id,
|
||||
"m.room.message",
|
||||
msg_content,
|
||||
)
|
||||
if isinstance(resp, nio.RoomSendResponse):
|
||||
last_event_id = resp.event_id
|
||||
else:
|
||||
err = getattr(resp, "message", str(resp))
|
||||
logger.error("Matrix: failed to send to %s: %s", chat_id, err)
|
||||
return SendResult(success=False, error=err)
|
||||
|
||||
return SendResult(success=True, message_id=last_event_id)
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Return room name and type (dm/group)."""
|
||||
name = chat_id
|
||||
chat_type = "group"
|
||||
|
||||
if self._client:
|
||||
room = self._client.rooms.get(chat_id)
|
||||
if room:
|
||||
name = room.display_name or room.canonical_alias or chat_id
|
||||
# Use DM cache.
|
||||
if self._dm_rooms.get(chat_id, False):
|
||||
chat_type = "dm"
|
||||
elif room.member_count == 2:
|
||||
chat_type = "dm"
|
||||
|
||||
return {"name": name, "type": chat_type}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Optional overrides
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send_typing(
|
||||
self, chat_id: str, metadata: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Send a typing indicator."""
|
||||
if self._client:
|
||||
try:
|
||||
await self._client.room_typing(chat_id, typing_state=True, timeout=30000)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def edit_message(
|
||||
self, chat_id: str, message_id: str, content: str
|
||||
) -> SendResult:
|
||||
"""Edit an existing message (via m.replace)."""
|
||||
import nio
|
||||
|
||||
formatted = self.format_message(content)
|
||||
msg_content: Dict[str, Any] = {
|
||||
"msgtype": "m.text",
|
||||
"body": f"* {formatted}",
|
||||
"m.new_content": {
|
||||
"msgtype": "m.text",
|
||||
"body": formatted,
|
||||
},
|
||||
"m.relates_to": {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": message_id,
|
||||
},
|
||||
}
|
||||
|
||||
html = self._markdown_to_html(formatted)
|
||||
if html and html != formatted:
|
||||
msg_content["m.new_content"]["format"] = "org.matrix.custom.html"
|
||||
msg_content["m.new_content"]["formatted_body"] = html
|
||||
msg_content["format"] = "org.matrix.custom.html"
|
||||
msg_content["formatted_body"] = f"* {html}"
|
||||
|
||||
resp = await self._client.room_send(chat_id, "m.room.message", msg_content)
|
||||
if isinstance(resp, nio.RoomSendResponse):
|
||||
return SendResult(success=True, message_id=resp.event_id)
|
||||
return SendResult(success=False, error=getattr(resp, "message", str(resp)))
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Download an image URL and upload it to Matrix."""
|
||||
try:
|
||||
# Try aiohttp first (always available), fall back to httpx
|
||||
try:
|
||||
import aiohttp as _aiohttp
|
||||
async with _aiohttp.ClientSession() as http:
|
||||
async with http.get(image_url, timeout=_aiohttp.ClientTimeout(total=30)) as resp:
|
||||
resp.raise_for_status()
|
||||
data = await resp.read()
|
||||
ct = resp.content_type or "image/png"
|
||||
fname = image_url.rsplit("/", 1)[-1].split("?")[0] or "image.png"
|
||||
except ImportError:
|
||||
import httpx
|
||||
async with httpx.AsyncClient() as http:
|
||||
resp = await http.get(image_url, follow_redirects=True, timeout=30)
|
||||
resp.raise_for_status()
|
||||
data = resp.content
|
||||
ct = resp.headers.get("content-type", "image/png")
|
||||
fname = image_url.rsplit("/", 1)[-1].split("?")[0] or "image.png"
|
||||
except Exception as exc:
|
||||
logger.warning("Matrix: failed to download image %s: %s", image_url, exc)
|
||||
return await self.send(chat_id, f"{caption or ''}\n{image_url}".strip(), reply_to)
|
||||
|
||||
return await self._upload_and_send(chat_id, data, fname, ct, "m.image", caption, reply_to, metadata)
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Upload a local image file to Matrix."""
|
||||
return await self._send_local_file(chat_id, image_path, "m.image", caption, reply_to, metadata=metadata)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Upload a local file as a document."""
|
||||
return await self._send_local_file(chat_id, file_path, "m.file", caption, reply_to, file_name, metadata)
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Upload an audio file as a voice message."""
|
||||
return await self._send_local_file(chat_id, audio_path, "m.audio", caption, reply_to, metadata=metadata)
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Upload a video file."""
|
||||
return await self._send_local_file(chat_id, video_path, "m.video", caption, reply_to, metadata=metadata)
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""Pass-through — Matrix supports standard Markdown natively."""
|
||||
# Strip image markdown; media is uploaded separately.
|
||||
content = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", r"\2", content)
|
||||
return content
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# File helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _upload_and_send(
|
||||
self,
|
||||
room_id: str,
|
||||
data: bytes,
|
||||
filename: str,
|
||||
content_type: str,
|
||||
msgtype: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Upload bytes to Matrix and send as a media message."""
|
||||
import nio
|
||||
|
||||
# Upload to homeserver.
|
||||
resp = await self._client.upload(
|
||||
data,
|
||||
content_type=content_type,
|
||||
filename=filename,
|
||||
)
|
||||
if not isinstance(resp, nio.UploadResponse):
|
||||
err = getattr(resp, "message", str(resp))
|
||||
logger.error("Matrix: upload failed: %s", err)
|
||||
return SendResult(success=False, error=err)
|
||||
|
||||
mxc_url = resp.content_uri
|
||||
|
||||
# Build media message content.
|
||||
msg_content: Dict[str, Any] = {
|
||||
"msgtype": msgtype,
|
||||
"body": caption or filename,
|
||||
"url": mxc_url,
|
||||
"info": {
|
||||
"mimetype": content_type,
|
||||
"size": len(data),
|
||||
},
|
||||
}
|
||||
|
||||
if reply_to:
|
||||
msg_content["m.relates_to"] = {
|
||||
"m.in_reply_to": {"event_id": reply_to}
|
||||
}
|
||||
|
||||
thread_id = (metadata or {}).get("thread_id")
|
||||
if thread_id:
|
||||
relates_to = msg_content.get("m.relates_to", {})
|
||||
relates_to["rel_type"] = "m.thread"
|
||||
relates_to["event_id"] = thread_id
|
||||
relates_to["is_falling_back"] = True
|
||||
msg_content["m.relates_to"] = relates_to
|
||||
|
||||
resp2 = await self._client.room_send(room_id, "m.room.message", msg_content)
|
||||
if isinstance(resp2, nio.RoomSendResponse):
|
||||
return SendResult(success=True, message_id=resp2.event_id)
|
||||
return SendResult(success=False, error=getattr(resp2, "message", str(resp2)))
|
||||
|
||||
async def _send_local_file(
|
||||
self,
|
||||
room_id: str,
|
||||
file_path: str,
|
||||
msgtype: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Read a local file and upload it."""
|
||||
p = Path(file_path)
|
||||
if not p.exists():
|
||||
return await self.send(
|
||||
room_id, f"{caption or ''}\n(file not found: {file_path})", reply_to
|
||||
)
|
||||
|
||||
fname = file_name or p.name
|
||||
ct = mimetypes.guess_type(fname)[0] or "application/octet-stream"
|
||||
data = p.read_bytes()
|
||||
|
||||
return await self._upload_and_send(room_id, data, fname, ct, msgtype, caption, reply_to, metadata)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sync loop
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
"""Continuously sync with the homeserver."""
|
||||
while not self._closing:
|
||||
try:
|
||||
await self._client.sync(timeout=30000)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as exc:
|
||||
if self._closing:
|
||||
return
|
||||
logger.warning("Matrix: sync error: %s — retrying in 5s", exc)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event callbacks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _on_room_message(self, room: Any, event: Any) -> None:
|
||||
"""Handle incoming text messages (and decrypted megolm events)."""
|
||||
import nio
|
||||
|
||||
# Ignore own messages.
|
||||
if event.sender == self._user_id:
|
||||
return
|
||||
|
||||
# Deduplicate by event ID (nio can fire the same event more than once).
|
||||
if self._is_duplicate_event(getattr(event, "event_id", None)):
|
||||
return
|
||||
|
||||
# Startup grace: ignore old messages from initial sync.
|
||||
event_ts = getattr(event, "server_timestamp", 0) / 1000.0
|
||||
if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS:
|
||||
return
|
||||
|
||||
# Handle decrypted MegolmEvents — extract the inner event.
|
||||
if isinstance(event, nio.MegolmEvent):
|
||||
# Failed to decrypt.
|
||||
logger.warning(
|
||||
"Matrix: could not decrypt event %s in %s",
|
||||
event.event_id, room.room_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Skip edits (m.replace relation).
|
||||
source_content = getattr(event, "source", {}).get("content", {})
|
||||
relates_to = source_content.get("m.relates_to", {})
|
||||
if relates_to.get("rel_type") == "m.replace":
|
||||
return
|
||||
|
||||
body = getattr(event, "body", "") or ""
|
||||
if not body:
|
||||
return
|
||||
|
||||
# Determine chat type.
|
||||
is_dm = self._dm_rooms.get(room.room_id, False)
|
||||
if not is_dm and room.member_count == 2:
|
||||
is_dm = True
|
||||
chat_type = "dm" if is_dm else "group"
|
||||
|
||||
# Thread support.
|
||||
thread_id = None
|
||||
if relates_to.get("rel_type") == "m.thread":
|
||||
thread_id = relates_to.get("event_id")
|
||||
|
||||
# Reply-to detection.
|
||||
reply_to = None
|
||||
in_reply_to = relates_to.get("m.in_reply_to", {})
|
||||
if in_reply_to:
|
||||
reply_to = in_reply_to.get("event_id")
|
||||
|
||||
# Strip reply fallback from body (Matrix prepends "> ..." lines).
|
||||
if reply_to and body.startswith("> "):
|
||||
lines = body.split("\n")
|
||||
stripped = []
|
||||
past_fallback = False
|
||||
for line in lines:
|
||||
if not past_fallback:
|
||||
if line.startswith("> ") or line == ">":
|
||||
continue
|
||||
if line == "":
|
||||
past_fallback = True
|
||||
continue
|
||||
past_fallback = True
|
||||
stripped.append(line)
|
||||
body = "\n".join(stripped) if stripped else body
|
||||
|
||||
# Message type.
|
||||
msg_type = MessageType.TEXT
|
||||
if body.startswith("!") or body.startswith("/"):
|
||||
msg_type = MessageType.COMMAND
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=room.room_id,
|
||||
chat_type=chat_type,
|
||||
user_id=event.sender,
|
||||
user_name=self._get_display_name(room, event.sender),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
msg_event = MessageEvent(
|
||||
text=body,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=getattr(event, "source", {}),
|
||||
message_id=event.event_id,
|
||||
reply_to_message_id=reply_to,
|
||||
)
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
async def _on_room_message_media(self, room: Any, event: Any) -> None:
|
||||
"""Handle incoming media messages (images, audio, video, files)."""
|
||||
import nio
|
||||
|
||||
# Ignore own messages.
|
||||
if event.sender == self._user_id:
|
||||
return
|
||||
|
||||
# Deduplicate by event ID.
|
||||
if self._is_duplicate_event(getattr(event, "event_id", None)):
|
||||
return
|
||||
|
||||
# Startup grace.
|
||||
event_ts = getattr(event, "server_timestamp", 0) / 1000.0
|
||||
if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS:
|
||||
return
|
||||
|
||||
body = getattr(event, "body", "") or ""
|
||||
url = getattr(event, "url", "")
|
||||
|
||||
# Convert mxc:// to HTTP URL for downstream processing.
|
||||
http_url = ""
|
||||
if url and url.startswith("mxc://"):
|
||||
http_url = self._mxc_to_http(url)
|
||||
|
||||
# Determine message type from event class.
|
||||
# Use the MIME type from the event's content info when available,
|
||||
# falling back to category-level MIME types for downstream matching
|
||||
# (gateway/run.py checks startswith("image/"), startswith("audio/"), etc.)
|
||||
content_info = getattr(event, "content", {}) if isinstance(getattr(event, "content", None), dict) else {}
|
||||
event_mimetype = (content_info.get("info") or {}).get("mimetype", "")
|
||||
media_type = "application/octet-stream"
|
||||
msg_type = MessageType.DOCUMENT
|
||||
if isinstance(event, nio.RoomMessageImage):
|
||||
msg_type = MessageType.PHOTO
|
||||
media_type = event_mimetype or "image/png"
|
||||
elif isinstance(event, nio.RoomMessageAudio):
|
||||
msg_type = MessageType.AUDIO
|
||||
media_type = event_mimetype or "audio/ogg"
|
||||
elif isinstance(event, nio.RoomMessageVideo):
|
||||
msg_type = MessageType.VIDEO
|
||||
media_type = event_mimetype or "video/mp4"
|
||||
elif event_mimetype:
|
||||
media_type = event_mimetype
|
||||
|
||||
# For images, download and cache locally so vision tools can access them.
|
||||
# Matrix MXC URLs require authentication, so direct URL access fails.
|
||||
cached_path = None
|
||||
if msg_type == MessageType.PHOTO and url:
|
||||
try:
|
||||
ext_map = {
|
||||
"image/jpeg": ".jpg", "image/png": ".png",
|
||||
"image/gif": ".gif", "image/webp": ".webp",
|
||||
}
|
||||
ext = ext_map.get(event_mimetype, ".jpg")
|
||||
download_resp = await self._client.download(url)
|
||||
if isinstance(download_resp, nio.DownloadResponse):
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
cached_path = cache_image_from_bytes(download_resp.body, ext=ext)
|
||||
logger.info("[Matrix] Cached user image at %s", cached_path)
|
||||
except Exception as e:
|
||||
logger.warning("[Matrix] Failed to cache image: %s", e)
|
||||
|
||||
is_dm = self._dm_rooms.get(room.room_id, False)
|
||||
if not is_dm and room.member_count == 2:
|
||||
is_dm = True
|
||||
chat_type = "dm" if is_dm else "group"
|
||||
|
||||
# Thread/reply detection.
|
||||
source_content = getattr(event, "source", {}).get("content", {})
|
||||
relates_to = source_content.get("m.relates_to", {})
|
||||
thread_id = None
|
||||
if relates_to.get("rel_type") == "m.thread":
|
||||
thread_id = relates_to.get("event_id")
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=room.room_id,
|
||||
chat_type=chat_type,
|
||||
user_id=event.sender,
|
||||
user_name=self._get_display_name(room, event.sender),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
# Use cached local path for images, HTTP URL for other media types
|
||||
media_urls = [cached_path] if cached_path else ([http_url] if http_url else None)
|
||||
media_types = [media_type] if media_urls else None
|
||||
|
||||
msg_event = MessageEvent(
|
||||
text=body,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=getattr(event, "source", {}),
|
||||
message_id=event.event_id,
|
||||
media_urls=media_urls,
|
||||
media_types=media_types,
|
||||
)
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
async def _on_invite(self, room: Any, event: Any) -> None:
|
||||
"""Auto-join rooms when invited."""
|
||||
import nio
|
||||
|
||||
if not isinstance(event, nio.InviteMemberEvent):
|
||||
return
|
||||
|
||||
# Only process invites directed at us.
|
||||
if event.state_key != self._user_id:
|
||||
return
|
||||
|
||||
if event.membership != "invite":
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Matrix: invited to %s by %s — joining",
|
||||
room.room_id, event.sender,
|
||||
)
|
||||
try:
|
||||
resp = await self._client.join(room.room_id)
|
||||
if isinstance(resp, nio.JoinResponse):
|
||||
self._joined_rooms.add(room.room_id)
|
||||
logger.info("Matrix: joined %s", room.room_id)
|
||||
# Refresh DM cache since new room may be a DM.
|
||||
await self._refresh_dm_cache()
|
||||
else:
|
||||
logger.warning(
|
||||
"Matrix: failed to join %s: %s",
|
||||
room.room_id, getattr(resp, "message", resp),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Matrix: error joining %s: %s", room.room_id, exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _refresh_dm_cache(self) -> None:
|
||||
"""Refresh the DM room cache from m.direct account data.
|
||||
|
||||
Tries the account_data API first, then falls back to parsing
|
||||
the sync response's account_data for robustness.
|
||||
"""
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
dm_data: Optional[Dict] = None
|
||||
|
||||
# Primary: try the dedicated account data endpoint.
|
||||
try:
|
||||
resp = await self._client.get_account_data("m.direct")
|
||||
if hasattr(resp, "content"):
|
||||
dm_data = resp.content
|
||||
elif isinstance(resp, dict):
|
||||
dm_data = resp
|
||||
except Exception as exc:
|
||||
logger.debug("Matrix: get_account_data('m.direct') failed: %s — trying sync fallback", exc)
|
||||
|
||||
# Fallback: parse from the client's account_data store (populated by sync).
|
||||
if dm_data is None:
|
||||
try:
|
||||
# matrix-nio stores account data events on the client object
|
||||
ad = getattr(self._client, "account_data", None)
|
||||
if ad and isinstance(ad, dict) and "m.direct" in ad:
|
||||
event = ad["m.direct"]
|
||||
if hasattr(event, "content"):
|
||||
dm_data = event.content
|
||||
elif isinstance(event, dict):
|
||||
dm_data = event
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if dm_data is None:
|
||||
return
|
||||
|
||||
dm_room_ids: Set[str] = set()
|
||||
for user_id, rooms in dm_data.items():
|
||||
if isinstance(rooms, list):
|
||||
dm_room_ids.update(rooms)
|
||||
|
||||
self._dm_rooms = {
|
||||
rid: (rid in dm_room_ids)
|
||||
for rid in self._joined_rooms
|
||||
}
|
||||
|
||||
def _get_display_name(self, room: Any, user_id: str) -> str:
|
||||
"""Get a user's display name in a room, falling back to user_id."""
|
||||
if room and hasattr(room, "users"):
|
||||
user = room.users.get(user_id)
|
||||
if user and getattr(user, "display_name", None):
|
||||
return user.display_name
|
||||
# Strip the @...:server format to just the localpart.
|
||||
if user_id.startswith("@") and ":" in user_id:
|
||||
return user_id[1:].split(":")[0]
|
||||
return user_id
|
||||
|
||||
def _mxc_to_http(self, mxc_url: str) -> str:
|
||||
"""Convert mxc://server/media_id to an HTTP download URL."""
|
||||
# mxc://matrix.org/abc123 → https://matrix.org/_matrix/client/v1/media/download/matrix.org/abc123
|
||||
# Uses the authenticated client endpoint (spec v1.11+) instead of the
|
||||
# deprecated /_matrix/media/v3/download/ path.
|
||||
if not mxc_url.startswith("mxc://"):
|
||||
return mxc_url
|
||||
parts = mxc_url[6:] # strip mxc://
|
||||
# Use our homeserver for download (federation handles the rest).
|
||||
return f"{self._homeserver}/_matrix/client/v1/media/download/{parts}"
|
||||
|
||||
def _markdown_to_html(self, text: str) -> str:
|
||||
"""Convert Markdown to Matrix-compatible HTML.
|
||||
|
||||
Uses a simple conversion for common patterns. For full fidelity
|
||||
a markdown-it style library could be used, but this covers the
|
||||
common cases without an extra dependency.
|
||||
"""
|
||||
try:
|
||||
import markdown
|
||||
html = markdown.markdown(
|
||||
text,
|
||||
extensions=["fenced_code", "tables", "nl2br"],
|
||||
)
|
||||
# Strip wrapping <p> tags for single-paragraph messages.
|
||||
if html.count("<p>") == 1:
|
||||
html = html.replace("<p>", "").replace("</p>", "")
|
||||
return html
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Minimal fallback: just handle bold, italic, code.
|
||||
html = text
|
||||
html = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", html)
|
||||
html = re.sub(r"\*(.+?)\*", r"<em>\1</em>", html)
|
||||
html = re.sub(r"`([^`]+)`", r"<code>\1</code>", html)
|
||||
html = re.sub(r"\n", r"<br>", html)
|
||||
return html
|
||||
@@ -0,0 +1,682 @@
|
||||
"""Mattermost gateway adapter.
|
||||
|
||||
Connects to a self-hosted (or cloud) Mattermost instance via its REST API
|
||||
(v4) and WebSocket for real-time events. No external Mattermost library
|
||||
required — uses aiohttp which is already a Hermes dependency.
|
||||
|
||||
Environment variables:
|
||||
MATTERMOST_URL Server URL (e.g. https://mm.example.com)
|
||||
MATTERMOST_TOKEN Bot token or personal-access token
|
||||
MATTERMOST_ALLOWED_USERS Comma-separated user IDs
|
||||
MATTERMOST_HOME_CHANNEL Channel ID for cron/notification delivery
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mattermost post size limit (server default is 16383, but 4000 is the
|
||||
# practical limit for readable messages — matching OpenClaw's choice).
|
||||
MAX_POST_LENGTH = 4000
|
||||
|
||||
# Channel type codes returned by the Mattermost API.
|
||||
_CHANNEL_TYPE_MAP = {
|
||||
"D": "dm",
|
||||
"G": "group",
|
||||
"P": "group", # private channel → treat as group
|
||||
"O": "channel",
|
||||
}
|
||||
|
||||
# Reconnect parameters (exponential backoff).
|
||||
_RECONNECT_BASE_DELAY = 2.0
|
||||
_RECONNECT_MAX_DELAY = 60.0
|
||||
_RECONNECT_JITTER = 0.2
|
||||
|
||||
|
||||
def check_mattermost_requirements() -> bool:
|
||||
"""Return True if the Mattermost adapter can be used."""
|
||||
token = os.getenv("MATTERMOST_TOKEN", "")
|
||||
url = os.getenv("MATTERMOST_URL", "")
|
||||
if not token:
|
||||
logger.debug("Mattermost: MATTERMOST_TOKEN not set")
|
||||
return False
|
||||
if not url:
|
||||
logger.warning("Mattermost: MATTERMOST_URL not set")
|
||||
return False
|
||||
try:
|
||||
import aiohttp # noqa: F401
|
||||
return True
|
||||
except ImportError:
|
||||
logger.warning("Mattermost: aiohttp not installed")
|
||||
return False
|
||||
|
||||
|
||||
class MattermostAdapter(BasePlatformAdapter):
|
||||
"""Gateway adapter for Mattermost (self-hosted or cloud)."""
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.MATTERMOST)
|
||||
|
||||
self._base_url: str = (
|
||||
config.extra.get("url", "")
|
||||
or os.getenv("MATTERMOST_URL", "")
|
||||
).rstrip("/")
|
||||
self._token: str = config.token or os.getenv("MATTERMOST_TOKEN", "")
|
||||
|
||||
self._bot_user_id: str = ""
|
||||
self._bot_username: str = ""
|
||||
|
||||
# aiohttp session + websocket handle
|
||||
self._session: Any = None # aiohttp.ClientSession
|
||||
self._ws: Any = None # aiohttp.ClientWebSocketResponse
|
||||
self._ws_task: Optional[asyncio.Task] = None
|
||||
self._reconnect_task: Optional[asyncio.Task] = None
|
||||
self._closing = False
|
||||
|
||||
# Reply mode: "thread" to nest replies, "off" for flat messages.
|
||||
self._reply_mode: str = (
|
||||
config.extra.get("reply_mode", "")
|
||||
or os.getenv("MATTERMOST_REPLY_MODE", "off")
|
||||
).lower()
|
||||
|
||||
# Dedup cache: post_id → timestamp (prevent reprocessing)
|
||||
self._seen_posts: Dict[str, float] = {}
|
||||
self._SEEN_MAX = 2000
|
||||
self._SEEN_TTL = 300 # 5 minutes
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self._token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def _api_get(self, path: str) -> Dict[str, Any]:
|
||||
"""GET /api/v4/{path}."""
|
||||
import aiohttp
|
||||
url = f"{self._base_url}/api/v4/{path.lstrip('/')}"
|
||||
try:
|
||||
async with self._session.get(url, headers=self._headers()) as resp:
|
||||
if resp.status >= 400:
|
||||
body = await resp.text()
|
||||
logger.error("MM API GET %s → %s: %s", path, resp.status, body[:200])
|
||||
return {}
|
||||
return await resp.json()
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.error("MM API GET %s network error: %s", path, exc)
|
||||
return {}
|
||||
|
||||
async def _api_post(
|
||||
self, path: str, payload: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""POST /api/v4/{path} with JSON body."""
|
||||
import aiohttp
|
||||
url = f"{self._base_url}/api/v4/{path.lstrip('/')}"
|
||||
try:
|
||||
async with self._session.post(
|
||||
url, headers=self._headers(), json=payload
|
||||
) as resp:
|
||||
if resp.status >= 400:
|
||||
body = await resp.text()
|
||||
logger.error("MM API POST %s → %s: %s", path, resp.status, body[:200])
|
||||
return {}
|
||||
return await resp.json()
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.error("MM API POST %s network error: %s", path, exc)
|
||||
return {}
|
||||
|
||||
async def _api_put(
|
||||
self, path: str, payload: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""PUT /api/v4/{path} with JSON body."""
|
||||
import aiohttp
|
||||
url = f"{self._base_url}/api/v4/{path.lstrip('/')}"
|
||||
try:
|
||||
async with self._session.put(
|
||||
url, headers=self._headers(), json=payload
|
||||
) as resp:
|
||||
if resp.status >= 400:
|
||||
body = await resp.text()
|
||||
logger.error("MM API PUT %s → %s: %s", path, resp.status, body[:200])
|
||||
return {}
|
||||
return await resp.json()
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.error("MM API PUT %s network error: %s", path, exc)
|
||||
return {}
|
||||
|
||||
async def _upload_file(
|
||||
self, channel_id: str, file_data: bytes, filename: str, content_type: str = "application/octet-stream"
|
||||
) -> Optional[str]:
|
||||
"""Upload a file and return its file ID, or None on failure."""
|
||||
import aiohttp
|
||||
|
||||
url = f"{self._base_url}/api/v4/files"
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("channel_id", channel_id)
|
||||
form.add_field(
|
||||
"files",
|
||||
file_data,
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
)
|
||||
headers = {"Authorization": f"Bearer {self._token}"}
|
||||
async with self._session.post(url, headers=headers, data=form) as resp:
|
||||
if resp.status >= 400:
|
||||
body = await resp.text()
|
||||
logger.error("MM file upload → %s: %s", resp.status, body[:200])
|
||||
return None
|
||||
data = await resp.json()
|
||||
infos = data.get("file_infos", [])
|
||||
return infos[0]["id"] if infos else None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Required overrides
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Mattermost and start the WebSocket listener."""
|
||||
import aiohttp
|
||||
|
||||
if not self._base_url or not self._token:
|
||||
logger.error("Mattermost: URL or token not configured")
|
||||
return False
|
||||
|
||||
self._session = aiohttp.ClientSession()
|
||||
self._closing = False
|
||||
|
||||
# Verify credentials and fetch bot identity.
|
||||
me = await self._api_get("users/me")
|
||||
if not me or "id" not in me:
|
||||
logger.error("Mattermost: failed to authenticate — check MATTERMOST_TOKEN and MATTERMOST_URL")
|
||||
await self._session.close()
|
||||
return False
|
||||
|
||||
self._bot_user_id = me["id"]
|
||||
self._bot_username = me.get("username", "")
|
||||
logger.info(
|
||||
"Mattermost: authenticated as @%s (%s) on %s",
|
||||
self._bot_username,
|
||||
self._bot_user_id,
|
||||
self._base_url,
|
||||
)
|
||||
|
||||
# Start WebSocket in background.
|
||||
self._ws_task = asyncio.create_task(self._ws_loop())
|
||||
self._mark_connected()
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Mattermost."""
|
||||
self._closing = True
|
||||
|
||||
if self._ws_task and not self._ws_task.done():
|
||||
self._ws_task.cancel()
|
||||
try:
|
||||
await self._ws_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
if self._reconnect_task and not self._reconnect_task.done():
|
||||
self._reconnect_task.cancel()
|
||||
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
logger.info("Mattermost: disconnected")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send a message (or multiple chunks) to a channel."""
|
||||
if not content:
|
||||
return SendResult(success=True)
|
||||
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, MAX_POST_LENGTH)
|
||||
|
||||
last_id = None
|
||||
for chunk in chunks:
|
||||
payload: Dict[str, Any] = {
|
||||
"channel_id": chat_id,
|
||||
"message": chunk,
|
||||
}
|
||||
# Thread support: reply_to is the root post ID.
|
||||
if reply_to and self._reply_mode == "thread":
|
||||
payload["root_id"] = reply_to
|
||||
|
||||
data = await self._api_post("posts", payload)
|
||||
if not data or "id" not in data:
|
||||
return SendResult(success=False, error="Failed to create post")
|
||||
last_id = data["id"]
|
||||
|
||||
return SendResult(success=True, message_id=last_id)
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Return channel name and type."""
|
||||
data = await self._api_get(f"channels/{chat_id}")
|
||||
if not data:
|
||||
return {"name": chat_id, "type": "channel"}
|
||||
|
||||
ch_type = _CHANNEL_TYPE_MAP.get(data.get("type", "O"), "channel")
|
||||
display_name = data.get("display_name") or data.get("name") or chat_id
|
||||
return {"name": display_name, "type": ch_type}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Optional overrides
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send_typing(
|
||||
self, chat_id: str, metadata: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Send a typing indicator."""
|
||||
await self._api_post(
|
||||
f"users/{self._bot_user_id}/typing",
|
||||
{"channel_id": chat_id},
|
||||
)
|
||||
|
||||
async def edit_message(
|
||||
self, chat_id: str, message_id: str, content: str
|
||||
) -> SendResult:
|
||||
"""Edit an existing post."""
|
||||
formatted = self.format_message(content)
|
||||
data = await self._api_put(
|
||||
f"posts/{message_id}/patch",
|
||||
{"message": formatted},
|
||||
)
|
||||
if not data or "id" not in data:
|
||||
return SendResult(success=False, error="Failed to edit post")
|
||||
return SendResult(success=True, message_id=data["id"])
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Download an image and upload it as a file attachment."""
|
||||
return await self._send_url_as_file(
|
||||
chat_id, image_url, caption, reply_to, "image"
|
||||
)
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Upload a local image file."""
|
||||
return await self._send_local_file(
|
||||
chat_id, image_path, caption, reply_to
|
||||
)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Upload a local file as a document."""
|
||||
return await self._send_local_file(
|
||||
chat_id, file_path, caption, reply_to, file_name
|
||||
)
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Upload an audio file."""
|
||||
return await self._send_local_file(
|
||||
chat_id, audio_path, caption, reply_to
|
||||
)
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Upload a video file."""
|
||||
return await self._send_local_file(
|
||||
chat_id, video_path, caption, reply_to
|
||||
)
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""Mattermost uses standard Markdown — mostly pass through.
|
||||
|
||||
Strip image markdown into plain links (files are uploaded separately).
|
||||
"""
|
||||
# Convert  to just the URL — Mattermost renders
|
||||
# image URLs as inline previews automatically.
|
||||
content = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", r"\2", content)
|
||||
return content
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# File helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _send_url_as_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
url: str,
|
||||
caption: Optional[str],
|
||||
reply_to: Optional[str],
|
||||
kind: str = "file",
|
||||
) -> SendResult:
|
||||
"""Download a URL and upload it as a file attachment."""
|
||||
import aiohttp
|
||||
try:
|
||||
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status >= 400:
|
||||
# Fall back to sending the URL as text.
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
file_data = await resp.read()
|
||||
ct = resp.content_type or "application/octet-stream"
|
||||
# Derive filename from URL.
|
||||
fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png"
|
||||
except Exception as exc:
|
||||
logger.warning("Mattermost: failed to download %s: %s", url, exc)
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
file_id = await self._upload_file(chat_id, file_data, fname, ct)
|
||||
if not file_id:
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"channel_id": chat_id,
|
||||
"message": caption or "",
|
||||
"file_ids": [file_id],
|
||||
}
|
||||
if reply_to and self._reply_mode == "thread":
|
||||
payload["root_id"] = reply_to
|
||||
|
||||
data = await self._api_post("posts", payload)
|
||||
if not data or "id" not in data:
|
||||
return SendResult(success=False, error="Failed to post with file")
|
||||
return SendResult(success=True, message_id=data["id"])
|
||||
|
||||
async def _send_local_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str],
|
||||
reply_to: Optional[str],
|
||||
file_name: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Upload a local file and attach it to a post."""
|
||||
import mimetypes
|
||||
|
||||
p = Path(file_path)
|
||||
if not p.exists():
|
||||
return await self.send(
|
||||
chat_id, f"{caption or ''}\n(file not found: {file_path})", reply_to
|
||||
)
|
||||
|
||||
fname = file_name or p.name
|
||||
ct = mimetypes.guess_type(fname)[0] or "application/octet-stream"
|
||||
file_data = p.read_bytes()
|
||||
|
||||
file_id = await self._upload_file(chat_id, file_data, fname, ct)
|
||||
if not file_id:
|
||||
return SendResult(success=False, error="File upload failed")
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"channel_id": chat_id,
|
||||
"message": caption or "",
|
||||
"file_ids": [file_id],
|
||||
}
|
||||
if reply_to and self._reply_mode == "thread":
|
||||
payload["root_id"] = reply_to
|
||||
|
||||
data = await self._api_post("posts", payload)
|
||||
if not data or "id" not in data:
|
||||
return SendResult(success=False, error="Failed to post with file")
|
||||
return SendResult(success=True, message_id=data["id"])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _ws_loop(self) -> None:
|
||||
"""Connect to the WebSocket and listen for events, reconnecting on failure."""
|
||||
delay = _RECONNECT_BASE_DELAY
|
||||
while not self._closing:
|
||||
try:
|
||||
await self._ws_connect_and_listen()
|
||||
# Clean disconnect — reset delay.
|
||||
delay = _RECONNECT_BASE_DELAY
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as exc:
|
||||
if self._closing:
|
||||
return
|
||||
logger.warning("Mattermost WS error: %s — reconnecting in %.0fs", exc, delay)
|
||||
|
||||
if self._closing:
|
||||
return
|
||||
|
||||
# Exponential backoff with jitter.
|
||||
import random
|
||||
jitter = delay * _RECONNECT_JITTER * random.random()
|
||||
await asyncio.sleep(delay + jitter)
|
||||
delay = min(delay * 2, _RECONNECT_MAX_DELAY)
|
||||
|
||||
async def _ws_connect_and_listen(self) -> None:
|
||||
"""Single WebSocket session: connect, authenticate, process events."""
|
||||
# Build WS URL: https:// → wss://, http:// → ws://
|
||||
ws_url = re.sub(r"^http", "ws", self._base_url) + "/api/v4/websocket"
|
||||
logger.info("Mattermost: connecting to %s", ws_url)
|
||||
|
||||
self._ws = await self._session.ws_connect(ws_url, heartbeat=30.0)
|
||||
|
||||
# Authenticate via the WebSocket.
|
||||
auth_msg = {
|
||||
"seq": 1,
|
||||
"action": "authentication_challenge",
|
||||
"data": {"token": self._token},
|
||||
}
|
||||
await self._ws.send_json(auth_msg)
|
||||
logger.info("Mattermost: WebSocket connected and authenticated")
|
||||
|
||||
async for raw_msg in self._ws:
|
||||
if self._closing:
|
||||
return
|
||||
|
||||
if raw_msg.type in (
|
||||
raw_msg.type.TEXT,
|
||||
raw_msg.type.BINARY,
|
||||
):
|
||||
try:
|
||||
event = json.loads(raw_msg.data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
await self._handle_ws_event(event)
|
||||
elif raw_msg.type in (
|
||||
raw_msg.type.ERROR,
|
||||
raw_msg.type.CLOSE,
|
||||
raw_msg.type.CLOSING,
|
||||
raw_msg.type.CLOSED,
|
||||
):
|
||||
logger.info("Mattermost: WebSocket closed (%s)", raw_msg.type)
|
||||
break
|
||||
|
||||
async def _handle_ws_event(self, event: Dict[str, Any]) -> None:
|
||||
"""Process a single WebSocket event."""
|
||||
event_type = event.get("event")
|
||||
if event_type != "posted":
|
||||
return
|
||||
|
||||
data = event.get("data", {})
|
||||
raw_post_str = data.get("post")
|
||||
if not raw_post_str:
|
||||
return
|
||||
|
||||
try:
|
||||
post = json.loads(raw_post_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return
|
||||
|
||||
# Ignore own messages.
|
||||
if post.get("user_id") == self._bot_user_id:
|
||||
return
|
||||
|
||||
# Ignore system posts.
|
||||
if post.get("type"):
|
||||
return
|
||||
|
||||
post_id = post.get("id", "")
|
||||
|
||||
# Dedup.
|
||||
self._prune_seen()
|
||||
if post_id in self._seen_posts:
|
||||
return
|
||||
self._seen_posts[post_id] = time.time()
|
||||
|
||||
# Build message event.
|
||||
channel_id = post.get("channel_id", "")
|
||||
channel_type_raw = data.get("channel_type", "O")
|
||||
chat_type = _CHANNEL_TYPE_MAP.get(channel_type_raw, "channel")
|
||||
|
||||
# For DMs, user_id is sufficient. For channels, check for @mention.
|
||||
message_text = post.get("message", "")
|
||||
|
||||
# Mention-only mode: skip channel messages that don't @mention the bot.
|
||||
# DMs (type "D") are always processed.
|
||||
if channel_type_raw != "D":
|
||||
mention_patterns = [
|
||||
f"@{self._bot_username}",
|
||||
f"@{self._bot_user_id}",
|
||||
]
|
||||
has_mention = any(
|
||||
pattern.lower() in message_text.lower()
|
||||
for pattern in mention_patterns
|
||||
)
|
||||
if not has_mention:
|
||||
logger.debug(
|
||||
"Mattermost: skipping non-DM message without @mention (channel=%s)",
|
||||
channel_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Resolve sender info.
|
||||
sender_id = post.get("user_id", "")
|
||||
sender_name = data.get("sender_name", "").lstrip("@") or sender_id
|
||||
|
||||
# Thread support: if the post is in a thread, use root_id.
|
||||
thread_id = post.get("root_id") or None
|
||||
|
||||
# Determine message type.
|
||||
file_ids = post.get("file_ids") or []
|
||||
msg_type = MessageType.TEXT
|
||||
if message_text.startswith("/"):
|
||||
msg_type = MessageType.COMMAND
|
||||
|
||||
# Download file attachments immediately (URLs require auth headers
|
||||
# that downstream tools won't have).
|
||||
media_urls: List[str] = []
|
||||
media_types: List[str] = []
|
||||
for fid in file_ids:
|
||||
try:
|
||||
file_info = await self._api_get(f"files/{fid}/info")
|
||||
fname = file_info.get("name", f"file_{fid}")
|
||||
ext = Path(fname).suffix or ""
|
||||
mime = file_info.get("mime_type", "application/octet-stream")
|
||||
|
||||
import aiohttp
|
||||
dl_url = f"{self._base_url}/api/v4/files/{fid}"
|
||||
async with self._session.get(
|
||||
dl_url,
|
||||
headers={"Authorization": f"Bearer {self._token}"},
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as resp:
|
||||
if resp.status < 400:
|
||||
file_data = await resp.read()
|
||||
from gateway.platforms.base import cache_image_from_bytes, cache_document_from_bytes
|
||||
if mime.startswith("image/"):
|
||||
local_path = cache_image_from_bytes(file_data, ext or ".png")
|
||||
media_urls.append(local_path)
|
||||
media_types.append(mime)
|
||||
elif mime.startswith("audio/"):
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
local_path = cache_audio_from_bytes(file_data, ext or ".ogg")
|
||||
media_urls.append(local_path)
|
||||
media_types.append(mime)
|
||||
else:
|
||||
local_path = cache_document_from_bytes(file_data, fname)
|
||||
media_urls.append(local_path)
|
||||
media_types.append(mime)
|
||||
else:
|
||||
logger.warning("Mattermost: failed to download file %s: HTTP %s", fid, resp.status)
|
||||
except Exception as exc:
|
||||
logger.warning("Mattermost: error downloading file %s: %s", fid, exc)
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=channel_id,
|
||||
chat_type=chat_type,
|
||||
user_id=sender_id,
|
||||
user_name=sender_name,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
msg_event = MessageEvent(
|
||||
text=message_text,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=post,
|
||||
message_id=post_id,
|
||||
media_urls=media_urls if media_urls else None,
|
||||
media_types=media_types if media_types else None,
|
||||
)
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
def _prune_seen(self) -> None:
|
||||
"""Remove expired entries from the dedup cache."""
|
||||
if len(self._seen_posts) < self._SEEN_MAX:
|
||||
return
|
||||
now = time.time()
|
||||
self._seen_posts = {
|
||||
pid: ts
|
||||
for pid, ts in self._seen_posts.items()
|
||||
if now - ts < self._SEEN_TTL
|
||||
}
|
||||
@@ -179,6 +179,11 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
# Normalize account for self-message filtering
|
||||
self._account_normalized = self.account.strip()
|
||||
|
||||
# Track recently sent message timestamps to prevent echo-back loops
|
||||
# in Note to Self / self-chat mode (mirrors WhatsApp recentlySentIds)
|
||||
self._recent_sent_timestamps: set = set()
|
||||
self._max_recent_timestamps = 50
|
||||
|
||||
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url, _redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled")
|
||||
@@ -353,10 +358,26 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
# Unwrap nested envelope if present
|
||||
envelope_data = envelope.get("envelope", envelope)
|
||||
|
||||
# Filter syncMessage envelopes (sent transcripts, read receipts, etc.)
|
||||
# signal-cli may set syncMessage to null vs omitting it, so check key existence
|
||||
# Handle syncMessage: extract "Note to Self" messages (sent to own account)
|
||||
# while still filtering other sync events (read receipts, typing, etc.)
|
||||
is_note_to_self = False
|
||||
if "syncMessage" in envelope_data:
|
||||
return
|
||||
sync_msg = envelope_data.get("syncMessage")
|
||||
if sync_msg and isinstance(sync_msg, dict):
|
||||
sent_msg = sync_msg.get("sentMessage")
|
||||
if sent_msg and isinstance(sent_msg, dict):
|
||||
dest = sent_msg.get("destinationNumber") or sent_msg.get("destination")
|
||||
sent_ts = sent_msg.get("timestamp")
|
||||
if dest == self._account_normalized:
|
||||
# Check if this is an echo of our own outbound reply
|
||||
if sent_ts and sent_ts in self._recent_sent_timestamps:
|
||||
self._recent_sent_timestamps.discard(sent_ts)
|
||||
return
|
||||
# Genuine user Note to Self — promote to dataMessage
|
||||
is_note_to_self = True
|
||||
envelope_data = {**envelope_data, "dataMessage": sent_msg}
|
||||
if not is_note_to_self:
|
||||
return
|
||||
|
||||
# Extract sender info
|
||||
sender = (
|
||||
@@ -371,8 +392,8 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
logger.debug("Signal: ignoring envelope with no sender")
|
||||
return
|
||||
|
||||
# Self-message filtering — prevent reply loops
|
||||
if self._account_normalized and sender == self._account_normalized:
|
||||
# Self-message filtering — prevent reply loops (but allow Note to Self)
|
||||
if self._account_normalized and sender == self._account_normalized and not is_note_to_self:
|
||||
return
|
||||
|
||||
# Filter stories
|
||||
@@ -457,7 +478,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if any(mt.startswith("audio/") for mt in media_types):
|
||||
msg_type = MessageType.VOICE
|
||||
elif any(mt.startswith("image/") for mt in media_types):
|
||||
msg_type = MessageType.IMAGE
|
||||
msg_type = MessageType.PHOTO
|
||||
|
||||
# Parse timestamp from envelope data (milliseconds since epoch)
|
||||
ts_ms = envelope_data.get("timestamp", 0)
|
||||
@@ -498,6 +519,13 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if not result:
|
||||
return None, ""
|
||||
|
||||
# Handle dict response (signal-cli returns {"data": "base64..."})
|
||||
if isinstance(result, dict):
|
||||
result = result.get("data")
|
||||
if not result:
|
||||
logger.warning("Signal: attachment response missing 'data' key")
|
||||
return None, ""
|
||||
|
||||
# Result is base64-encoded file content
|
||||
raw_data = base64.b64decode(result)
|
||||
ext = _guess_extension(raw_data)
|
||||
@@ -577,9 +605,18 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
result = await self._rpc("send", params)
|
||||
|
||||
if result is not None:
|
||||
self._track_sent_timestamp(result)
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send failed")
|
||||
|
||||
def _track_sent_timestamp(self, rpc_result) -> None:
|
||||
"""Record outbound message timestamp for echo-back filtering."""
|
||||
ts = rpc_result.get("timestamp") if isinstance(rpc_result, dict) else None
|
||||
if ts:
|
||||
self._recent_sent_timestamps.add(ts)
|
||||
if len(self._recent_sent_timestamps) > self._max_recent_timestamps:
|
||||
self._recent_sent_timestamps.pop()
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""Send a typing indicator."""
|
||||
params: Dict[str, Any] = {
|
||||
@@ -635,6 +672,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
result = await self._rpc("send", params)
|
||||
if result is not None:
|
||||
self._track_sent_timestamp(result)
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send with attachment failed")
|
||||
|
||||
@@ -665,6 +703,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
result = await self._rpc("send", params)
|
||||
if result is not None:
|
||||
self._track_sent_timestamp(result)
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send document failed")
|
||||
|
||||
|
||||
@@ -789,23 +789,11 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
user_id = command.get("user_id", "")
|
||||
channel_id = command.get("channel_id", "")
|
||||
|
||||
# Map subcommands to gateway commands
|
||||
subcommand_map = {
|
||||
"new": "/reset", "reset": "/reset",
|
||||
"status": "/status", "stop": "/stop",
|
||||
"help": "/help",
|
||||
"model": "/model", "personality": "/personality",
|
||||
"retry": "/retry", "undo": "/undo",
|
||||
"compact": "/compress", "compress": "/compress",
|
||||
"resume": "/resume",
|
||||
"background": "/background",
|
||||
"usage": "/usage",
|
||||
"insights": "/insights",
|
||||
"title": "/title",
|
||||
"reasoning": "/reasoning",
|
||||
"provider": "/provider",
|
||||
"rollback": "/rollback",
|
||||
}
|
||||
# Map subcommands to gateway commands — derived from central registry.
|
||||
# Also keep "compact" as a Slack-specific alias for /compress.
|
||||
from hermes_cli.commands import slack_subcommand_map
|
||||
subcommand_map = slack_subcommand_map()
|
||||
subcommand_map["compact"] = "/compress"
|
||||
first_word = text.split()[0] if text else ""
|
||||
if first_word in subcommand_map:
|
||||
# Preserve arguments after the subcommand
|
||||
|
||||
@@ -0,0 +1,271 @@
|
||||
"""SMS (Twilio) platform adapter.
|
||||
|
||||
Connects to the Twilio REST API for outbound SMS and runs an aiohttp
|
||||
webhook server to receive inbound messages.
|
||||
|
||||
Shares credentials with the optional telephony skill — same env vars:
|
||||
- TWILIO_ACCOUNT_SID
|
||||
- TWILIO_AUTH_TOKEN
|
||||
- TWILIO_PHONE_NUMBER (E.164 from-number, e.g. +15551234567)
|
||||
|
||||
Gateway-specific env vars:
|
||||
- SMS_WEBHOOK_PORT (default 8080)
|
||||
- SMS_ALLOWED_USERS (comma-separated E.164 phone numbers)
|
||||
- SMS_ALLOW_ALL_USERS (true/false)
|
||||
- SMS_HOME_CHANNEL (phone number for cron delivery)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts"
|
||||
MAX_SMS_LENGTH = 1600 # ~10 SMS segments
|
||||
DEFAULT_WEBHOOK_PORT = 8080
|
||||
|
||||
# E.164 phone number pattern for redaction
|
||||
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
||||
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +1555***4567."""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:5] + "***" + phone[-4:]
|
||||
|
||||
|
||||
def check_sms_requirements() -> bool:
|
||||
"""Check if SMS adapter dependencies are available."""
|
||||
try:
|
||||
import aiohttp # noqa: F401
|
||||
except ImportError:
|
||||
return False
|
||||
return bool(os.getenv("TWILIO_ACCOUNT_SID") and os.getenv("TWILIO_AUTH_TOKEN"))
|
||||
|
||||
|
||||
class SmsAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Twilio SMS <-> Hermes gateway adapter.
|
||||
|
||||
Each inbound phone number gets its own Hermes session (multi-tenant).
|
||||
Replies are always sent from the configured TWILIO_PHONE_NUMBER.
|
||||
"""
|
||||
|
||||
MAX_MESSAGE_LENGTH = MAX_SMS_LENGTH
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.SMS)
|
||||
self._account_sid: str = os.environ["TWILIO_ACCOUNT_SID"]
|
||||
self._auth_token: str = os.environ["TWILIO_AUTH_TOKEN"]
|
||||
self._from_number: str = os.getenv("TWILIO_PHONE_NUMBER", "")
|
||||
self._webhook_port: int = int(
|
||||
os.getenv("SMS_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT))
|
||||
)
|
||||
self._runner = None
|
||||
self._http_session: Optional["aiohttp.ClientSession"] = None
|
||||
|
||||
def _basic_auth_header(self) -> str:
|
||||
"""Build HTTP Basic auth header value for Twilio."""
|
||||
creds = f"{self._account_sid}:{self._auth_token}"
|
||||
encoded = base64.b64encode(creds.encode("ascii")).decode("ascii")
|
||||
return f"Basic {encoded}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Required abstract methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
if not self._from_number:
|
||||
logger.error("[sms] TWILIO_PHONE_NUMBER not set — cannot send replies")
|
||||
return False
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_post("/webhooks/twilio", self._handle_webhook)
|
||||
app.router.add_get("/health", lambda _: web.Response(text="ok"))
|
||||
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, "0.0.0.0", self._webhook_port)
|
||||
await site.start()
|
||||
self._http_session = aiohttp.ClientSession()
|
||||
self._running = True
|
||||
|
||||
logger.info(
|
||||
"[sms] Twilio webhook server listening on port %d, from: %s",
|
||||
self._webhook_port,
|
||||
_redact_phone(self._from_number),
|
||||
)
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._http_session:
|
||||
await self._http_session.close()
|
||||
self._http_session = None
|
||||
if self._runner:
|
||||
await self._runner.cleanup()
|
||||
self._runner = None
|
||||
self._running = False
|
||||
logger.info("[sms] Disconnected")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
import aiohttp
|
||||
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted)
|
||||
last_result = SendResult(success=True)
|
||||
|
||||
url = f"{TWILIO_API_BASE}/{self._account_sid}/Messages.json"
|
||||
headers = {
|
||||
"Authorization": self._basic_auth_header(),
|
||||
}
|
||||
|
||||
session = self._http_session or aiohttp.ClientSession()
|
||||
try:
|
||||
for chunk in chunks:
|
||||
form_data = aiohttp.FormData()
|
||||
form_data.add_field("From", self._from_number)
|
||||
form_data.add_field("To", chat_id)
|
||||
form_data.add_field("Body", chunk)
|
||||
|
||||
try:
|
||||
async with session.post(url, data=form_data, headers=headers) as resp:
|
||||
body = await resp.json()
|
||||
if resp.status >= 400:
|
||||
error_msg = body.get("message", str(body))
|
||||
logger.error(
|
||||
"[sms] send failed to %s: %s %s",
|
||||
_redact_phone(chat_id),
|
||||
resp.status,
|
||||
error_msg,
|
||||
)
|
||||
return SendResult(
|
||||
success=False,
|
||||
error=f"Twilio {resp.status}: {error_msg}",
|
||||
)
|
||||
msg_sid = body.get("sid", "")
|
||||
last_result = SendResult(success=True, message_id=msg_sid)
|
||||
except Exception as e:
|
||||
logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e)
|
||||
return SendResult(success=False, error=str(e))
|
||||
finally:
|
||||
# Close session only if we created a fallback (no persistent session)
|
||||
if not self._http_session and session:
|
||||
await session.close()
|
||||
|
||||
return last_result
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SMS-specific formatting
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""Strip markdown — SMS renders it as literal characters."""
|
||||
content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"_(.+?)_", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"```[a-z]*\n?", "", content)
|
||||
content = re.sub(r"`(.+?)`", r"\1", content)
|
||||
content = re.sub(r"^#{1,6}\s+", "", content, flags=re.MULTILINE)
|
||||
content = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", content)
|
||||
content = re.sub(r"\n{3,}", "\n\n", content)
|
||||
return content.strip()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Twilio webhook handler
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_webhook(self, request) -> "aiohttp.web.Response":
|
||||
from aiohttp import web
|
||||
|
||||
try:
|
||||
raw = await request.read()
|
||||
# Twilio sends form-encoded data, not JSON
|
||||
form = urllib.parse.parse_qs(raw.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.error("[sms] webhook parse error: %s", e)
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
status=400,
|
||||
)
|
||||
|
||||
# Extract fields (parse_qs returns lists)
|
||||
from_number = (form.get("From", [""]))[0].strip()
|
||||
to_number = (form.get("To", [""]))[0].strip()
|
||||
text = (form.get("Body", [""]))[0].strip()
|
||||
message_sid = (form.get("MessageSid", [""]))[0].strip()
|
||||
|
||||
if not from_number or not text:
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
)
|
||||
|
||||
# Ignore messages from our own number (echo prevention)
|
||||
if from_number == self._from_number:
|
||||
logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number))
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[sms] inbound from %s -> %s: %s",
|
||||
_redact_phone(from_number),
|
||||
_redact_phone(to_number),
|
||||
text[:80],
|
||||
)
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=from_number,
|
||||
chat_name=from_number,
|
||||
chat_type="dm",
|
||||
user_id=from_number,
|
||||
user_name=from_number,
|
||||
)
|
||||
event = MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message=form,
|
||||
message_id=message_sid,
|
||||
)
|
||||
|
||||
# Non-blocking: Twilio expects a fast response
|
||||
asyncio.create_task(self.handle_message(event))
|
||||
|
||||
# Return empty TwiML — we send replies via the REST API, not inline TwiML
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
)
|
||||
+431
-65
@@ -79,8 +79,8 @@ def _escape_mdv2(text: str) -> str:
|
||||
def _strip_mdv2(text: str) -> str:
|
||||
"""Strip MarkdownV2 escape backslashes to produce clean plain text.
|
||||
|
||||
Also removes MarkdownV2 bold markers (*text* -> text) so the fallback
|
||||
doesn't show stray asterisks from header/bold conversion.
|
||||
Also removes MarkdownV2 formatting markers so the fallback
|
||||
doesn't show stray syntax characters from format_message conversion.
|
||||
"""
|
||||
# Remove escape backslashes before special characters
|
||||
cleaned = re.sub(r'\\([_*\[\]()~`>#\+\-=|{}.!\\])', r'\1', text)
|
||||
@@ -89,6 +89,10 @@ def _strip_mdv2(text: str) -> str:
|
||||
# Remove MarkdownV2 italic markers that format_message converted from *italic*
|
||||
# Use word boundary (\b) to avoid breaking snake_case like my_variable_name
|
||||
cleaned = re.sub(r'(?<!\w)_([^_]+)_(?!\w)', r'\1', cleaned)
|
||||
# Remove MarkdownV2 strikethrough markers (~text~ → text)
|
||||
cleaned = re.sub(r'~([^~]+)~', r'\1', cleaned)
|
||||
# Remove MarkdownV2 spoiler markers (||text|| → text)
|
||||
cleaned = re.sub(r'\|\|([^|]+)\|\|', r'\1', cleaned)
|
||||
return cleaned
|
||||
|
||||
|
||||
@@ -118,8 +122,16 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self._pending_photo_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._media_group_events: Dict[str, MessageEvent] = {}
|
||||
self._media_group_tasks: Dict[str, asyncio.Task] = {}
|
||||
# Buffer rapid text messages so Telegram client-side splits of long
|
||||
# messages are aggregated into a single MessageEvent.
|
||||
self._text_batch_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_DELAY_SECONDS", "0.6"))
|
||||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
self._polling_error_task: Optional[asyncio.Task] = None
|
||||
self._polling_conflict_count: int = 0
|
||||
self._polling_network_error_count: int = 0
|
||||
self._polling_error_callback_ref = None
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_polling_conflict(error: Exception) -> bool:
|
||||
@@ -130,13 +142,126 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
or "another bot instance is running" in text
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_network_error(error: Exception) -> bool:
|
||||
"""Return True for transient network errors that warrant a reconnect attempt."""
|
||||
name = error.__class__.__name__.lower()
|
||||
if name in ("networkerror", "timedout", "connectionerror"):
|
||||
return True
|
||||
try:
|
||||
from telegram.error import NetworkError, TimedOut
|
||||
if isinstance(error, (NetworkError, TimedOut)):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
return isinstance(error, OSError)
|
||||
|
||||
async def _handle_polling_network_error(self, error: Exception) -> None:
|
||||
"""Reconnect polling after a transient network interruption.
|
||||
|
||||
Triggered by NetworkError/TimedOut in the polling error callback, which
|
||||
happen when the host loses connectivity (Mac sleep, WiFi switch, VPN
|
||||
reconnect, etc.). The gateway process stays alive but the long-poll
|
||||
connection silently dies; without this handler the bot never recovers.
|
||||
|
||||
Strategy: exponential back-off (5s, 10s, 20s, 40s, 60s cap) up to
|
||||
MAX_NETWORK_RETRIES attempts, then mark the adapter retryable-fatal so
|
||||
the supervisor restarts the gateway process.
|
||||
"""
|
||||
if self.has_fatal_error:
|
||||
return
|
||||
|
||||
MAX_NETWORK_RETRIES = 10
|
||||
BASE_DELAY = 5
|
||||
MAX_DELAY = 60
|
||||
|
||||
self._polling_network_error_count += 1
|
||||
attempt = self._polling_network_error_count
|
||||
|
||||
if attempt > MAX_NETWORK_RETRIES:
|
||||
message = (
|
||||
"Telegram polling could not reconnect after %d network error retries. "
|
||||
"Restarting gateway." % MAX_NETWORK_RETRIES
|
||||
)
|
||||
logger.error("[%s] %s Last error: %s", self.name, message, error)
|
||||
self._set_fatal_error("telegram_network_error", message, retryable=True)
|
||||
await self._notify_fatal_error()
|
||||
return
|
||||
|
||||
delay = min(BASE_DELAY * (2 ** (attempt - 1)), MAX_DELAY)
|
||||
logger.warning(
|
||||
"[%s] Telegram network error (attempt %d/%d), reconnecting in %ds. Error: %s",
|
||||
self.name, attempt, MAX_NETWORK_RETRIES, delay, error,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
if self._app and self._app.updater and self._app.updater.running:
|
||||
await self._app.updater.stop()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
drop_pending_updates=False,
|
||||
error_callback=self._polling_error_callback_ref,
|
||||
)
|
||||
logger.info(
|
||||
"[%s] Telegram polling resumed after network error (attempt %d)",
|
||||
self.name, attempt,
|
||||
)
|
||||
self._polling_network_error_count = 0
|
||||
except Exception as retry_err:
|
||||
logger.warning("[%s] Telegram polling reconnect failed: %s", self.name, retry_err)
|
||||
# The next network error will trigger another attempt.
|
||||
|
||||
async def _handle_polling_conflict(self, error: Exception) -> None:
|
||||
if self.has_fatal_error and self.fatal_error_code == "telegram_polling_conflict":
|
||||
return
|
||||
# Track consecutive conflicts — transient 409s can occur when a
|
||||
# previous gateway instance hasn't fully released its long-poll
|
||||
# session on Telegram's server (e.g. during --replace handoffs or
|
||||
# systemd Restart=on-failure respawns). Retry a few times before
|
||||
# giving up, so the old session has time to expire.
|
||||
self._polling_conflict_count += 1
|
||||
|
||||
MAX_CONFLICT_RETRIES = 3
|
||||
RETRY_DELAY = 10 # seconds
|
||||
|
||||
if self._polling_conflict_count <= MAX_CONFLICT_RETRIES:
|
||||
logger.warning(
|
||||
"[%s] Telegram polling conflict (%d/%d), will retry in %ds. Error: %s",
|
||||
self.name, self._polling_conflict_count, MAX_CONFLICT_RETRIES,
|
||||
RETRY_DELAY, error,
|
||||
)
|
||||
try:
|
||||
if self._app and self._app.updater and self._app.updater.running:
|
||||
await self._app.updater.stop()
|
||||
except Exception:
|
||||
pass
|
||||
await asyncio.sleep(RETRY_DELAY)
|
||||
try:
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
drop_pending_updates=False,
|
||||
error_callback=self._polling_error_callback_ref,
|
||||
)
|
||||
logger.info("[%s] Telegram polling resumed after conflict retry %d", self.name, self._polling_conflict_count)
|
||||
self._polling_conflict_count = 0 # reset on success
|
||||
return
|
||||
except Exception as retry_err:
|
||||
logger.warning("[%s] Telegram polling retry failed: %s", self.name, retry_err)
|
||||
# Don't fall through to fatal yet — wait for the next conflict
|
||||
# to trigger another retry attempt (up to MAX_CONFLICT_RETRIES).
|
||||
return
|
||||
|
||||
# Exhausted retries — fatal
|
||||
message = (
|
||||
"Another Telegram bot poller is already using this token. "
|
||||
"Hermes stopped Telegram polling to avoid endless retry spam. "
|
||||
"Hermes stopped Telegram polling after %d retries. "
|
||||
"Make sure only one gateway instance is running for this bot token."
|
||||
% MAX_CONFLICT_RETRIES
|
||||
)
|
||||
logger.error("[%s] %s Original error: %s", self.name, message, error)
|
||||
self._set_fatal_error("telegram_polling_conflict", message, retryable=False)
|
||||
@@ -202,18 +327,42 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self._handle_media_message
|
||||
))
|
||||
|
||||
# Start polling in background
|
||||
await self._app.initialize()
|
||||
# Start polling — retry initialize() for transient TLS resets
|
||||
try:
|
||||
from telegram.error import NetworkError, TimedOut
|
||||
except ImportError:
|
||||
NetworkError = TimedOut = OSError # type: ignore[misc,assignment]
|
||||
_max_connect = 3
|
||||
for _attempt in range(_max_connect):
|
||||
try:
|
||||
await self._app.initialize()
|
||||
break
|
||||
except (NetworkError, TimedOut, OSError) as init_err:
|
||||
if _attempt < _max_connect - 1:
|
||||
wait = 2 ** _attempt
|
||||
logger.warning(
|
||||
"[%s] Connect attempt %d/%d failed: %s — retrying in %ds",
|
||||
self.name, _attempt + 1, _max_connect, init_err, wait,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
else:
|
||||
raise
|
||||
await self._app.start()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def _polling_error_callback(error: Exception) -> None:
|
||||
if not self._looks_like_polling_conflict(error):
|
||||
logger.error("[%s] Telegram polling error: %s", self.name, error, exc_info=True)
|
||||
return
|
||||
if self._polling_error_task and not self._polling_error_task.done():
|
||||
return
|
||||
self._polling_error_task = loop.create_task(self._handle_polling_conflict(error))
|
||||
if self._looks_like_polling_conflict(error):
|
||||
self._polling_error_task = loop.create_task(self._handle_polling_conflict(error))
|
||||
elif self._looks_like_network_error(error):
|
||||
logger.warning("[%s] Telegram network error, scheduling reconnect: %s", self.name, error)
|
||||
self._polling_error_task = loop.create_task(self._handle_polling_network_error(error))
|
||||
else:
|
||||
logger.error("[%s] Telegram polling error: %s", self.name, error, exc_info=True)
|
||||
|
||||
# Store reference for retry use in _handle_polling_conflict
|
||||
self._polling_error_callback_ref = _polling_error_callback
|
||||
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
@@ -222,29 +371,13 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
)
|
||||
|
||||
# Register bot commands so Telegram shows a hint menu when users type /
|
||||
# List is derived from the central COMMAND_REGISTRY — adding a new
|
||||
# gateway command there automatically adds it to the Telegram menu.
|
||||
try:
|
||||
from telegram import BotCommand
|
||||
from hermes_cli.commands import telegram_bot_commands
|
||||
await self._bot.set_my_commands([
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("reset", "Reset conversation history"),
|
||||
BotCommand("model", "Show or change the model"),
|
||||
BotCommand("reasoning", "Show or change reasoning effort"),
|
||||
BotCommand("personality", "Set a personality"),
|
||||
BotCommand("retry", "Retry your last message"),
|
||||
BotCommand("undo", "Remove the last exchange"),
|
||||
BotCommand("status", "Show session info"),
|
||||
BotCommand("stop", "Stop the running agent"),
|
||||
BotCommand("sethome", "Set this chat as the home channel"),
|
||||
BotCommand("compress", "Compress conversation context"),
|
||||
BotCommand("title", "Set or show the session title"),
|
||||
BotCommand("resume", "Resume a previously-named session"),
|
||||
BotCommand("usage", "Show token usage for this session"),
|
||||
BotCommand("provider", "Show available providers"),
|
||||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("voice", "Toggle voice reply mode"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
BotCommand(name, desc) for name, desc in telegram_bot_commands()
|
||||
])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
@@ -265,6 +398,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
|
||||
except Exception:
|
||||
pass
|
||||
message = f"Telegram startup failed: {e}"
|
||||
self._set_fatal_error("telegram_connect_error", message, retryable=True)
|
||||
logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -322,36 +457,59 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
if len(chunks) > 1:
|
||||
# truncate_message appends a raw " (1/2)" suffix. Escape the
|
||||
# MarkdownV2-special parentheses so Telegram doesn't reject the
|
||||
# chunk and fall back to plain text.
|
||||
chunks = [
|
||||
re.sub(r" \((\d+)/(\d+)\)$", r" \\(\1/\2\\)", chunk)
|
||||
for chunk in chunks
|
||||
]
|
||||
|
||||
message_ids = []
|
||||
thread_id = metadata.get("thread_id") if metadata else None
|
||||
|
||||
try:
|
||||
from telegram.error import NetworkError as _NetErr
|
||||
except ImportError:
|
||||
_NetErr = OSError # type: ignore[misc,assignment]
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Try Markdown first, fall back to plain text if it fails
|
||||
try:
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=chunk,
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
|
||||
# Strip MDV2 escape backslashes so the user doesn't
|
||||
# see raw backslashes littered through the message.
|
||||
plain_chunk = _strip_mdv2(chunk)
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=plain_chunk,
|
||||
parse_mode=None, # Plain text
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
else:
|
||||
raise # Re-raise if not a parse error
|
||||
msg = None
|
||||
for _send_attempt in range(3):
|
||||
try:
|
||||
# Try Markdown first, fall back to plain text if it fails
|
||||
try:
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=chunk,
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
|
||||
plain_chunk = _strip_mdv2(chunk)
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=plain_chunk,
|
||||
parse_mode=None,
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
break # success
|
||||
except _NetErr as send_err:
|
||||
if _send_attempt < 2:
|
||||
wait = 2 ** _send_attempt
|
||||
logger.warning("[%s] Network error on send (attempt %d/3), retrying in %ds: %s",
|
||||
self.name, _send_attempt + 1, wait, send_err)
|
||||
await asyncio.sleep(wait)
|
||||
else:
|
||||
raise
|
||||
message_ids.append(str(msg.message_id))
|
||||
|
||||
return SendResult(
|
||||
@@ -382,7 +540,10 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
text=formatted,
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
)
|
||||
except Exception:
|
||||
except Exception as fmt_err:
|
||||
# "Message is not modified" is a no-op, not an error
|
||||
if "not modified" in str(fmt_err).lower():
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
# Fallback: retry without markdown formatting
|
||||
await self._bot.edit_message_text(
|
||||
chat_id=int(chat_id),
|
||||
@@ -391,6 +552,46 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as e:
|
||||
err_str = str(e).lower()
|
||||
# "Message is not modified" — content identical, treat as success
|
||||
if "not modified" in err_str:
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
# Message too long — content exceeded 4096 chars (e.g. during
|
||||
# streaming). Truncate and succeed so the stream consumer can
|
||||
# split the overflow into a new message instead of dying.
|
||||
if "message_too_long" in err_str or "too long" in err_str:
|
||||
truncated = content[: self.MAX_MESSAGE_LENGTH - 20] + "…"
|
||||
try:
|
||||
await self._bot.edit_message_text(
|
||||
chat_id=int(chat_id),
|
||||
message_id=int(message_id),
|
||||
text=truncated,
|
||||
)
|
||||
except Exception:
|
||||
pass # best-effort truncation
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
# Flood control / RetryAfter — back off and retry once
|
||||
retry_after = getattr(e, "retry_after", None)
|
||||
if retry_after is not None or "retry after" in err_str:
|
||||
wait = retry_after if retry_after else 1.0
|
||||
logger.warning(
|
||||
"[%s] Telegram flood control, waiting %.1fs",
|
||||
self.name, wait,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
try:
|
||||
await self._bot.edit_message_text(
|
||||
chat_id=int(chat_id),
|
||||
message_id=int(message_id),
|
||||
text=content,
|
||||
)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as retry_err:
|
||||
logger.error(
|
||||
"[%s] Edit retry failed after flood wait: %s",
|
||||
self.name, retry_err,
|
||||
)
|
||||
return SendResult(success=False, error=str(retry_err))
|
||||
logger.error(
|
||||
"[%s] Failed to edit Telegram message %s: %s",
|
||||
self.name,
|
||||
@@ -455,23 +656,26 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Telegram photo."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import os
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
|
||||
_thread = metadata.get("thread_id") if metadata else None
|
||||
with open(image_path, "rb") as image_file:
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
message_thread_id=int(_thread) if _thread else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
@@ -490,6 +694,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a document/file natively as a Telegram file attachment."""
|
||||
@@ -501,6 +706,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
|
||||
display_name = file_name or os.path.basename(file_path)
|
||||
_thread = metadata.get("thread_id") if metadata else None
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
msg = await self._bot.send_document(
|
||||
@@ -509,6 +715,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
filename=display_name,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
message_thread_id=int(_thread) if _thread else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
@@ -521,6 +728,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a video natively as a Telegram video message."""
|
||||
@@ -531,12 +739,14 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
if not os.path.exists(video_path):
|
||||
return SendResult(success=False, error=f"Video file not found: {video_path}")
|
||||
|
||||
_thread = metadata.get("thread_id") if metadata else None
|
||||
with open(video_path, "rb") as f:
|
||||
msg = await self._bot.send_video(
|
||||
chat_id=int(chat_id),
|
||||
video=f,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
message_thread_id=int(_thread) if _thread else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
@@ -712,14 +922,30 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
text = content
|
||||
|
||||
# 1) Protect fenced code blocks (``` ... ```)
|
||||
# Per MarkdownV2 spec, \ and ` inside pre/code must be escaped.
|
||||
def _protect_fenced(m):
|
||||
raw = m.group(0)
|
||||
# Split off opening ``` (with optional language) and closing ```
|
||||
open_end = raw.index('\n') + 1 if '\n' in raw[3:] else 3
|
||||
opening = raw[:open_end]
|
||||
body_and_close = raw[open_end:]
|
||||
body = body_and_close[:-3]
|
||||
body = body.replace('\\', '\\\\').replace('`', '\\`')
|
||||
return _ph(opening + body + '```')
|
||||
|
||||
text = re.sub(
|
||||
r'(```(?:[^\n]*\n)?[\s\S]*?```)',
|
||||
lambda m: _ph(m.group(0)),
|
||||
_protect_fenced,
|
||||
text,
|
||||
)
|
||||
|
||||
# 2) Protect inline code (`...`)
|
||||
text = re.sub(r'(`[^`]+`)', lambda m: _ph(m.group(0)), text)
|
||||
# Escape \ inside inline code per MarkdownV2 spec.
|
||||
text = re.sub(
|
||||
r'(`[^`]+`)',
|
||||
lambda m: _ph(m.group(0).replace('\\', '\\\\')),
|
||||
text,
|
||||
)
|
||||
|
||||
# 3) Convert markdown links – escape the display text; inside the URL
|
||||
# only ')' and '\' need escaping per the MarkdownV2 spec.
|
||||
@@ -757,23 +983,89 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
text,
|
||||
)
|
||||
|
||||
# 7) Escape remaining special characters in plain text
|
||||
# 7) Convert strikethrough: ~~text~~ → ~text~ (MarkdownV2)
|
||||
text = re.sub(
|
||||
r'~~(.+?)~~',
|
||||
lambda m: _ph(f'~{_escape_mdv2(m.group(1))}~'),
|
||||
text,
|
||||
)
|
||||
|
||||
# 8) Convert spoiler: ||text|| → ||text|| (protect from | escaping)
|
||||
text = re.sub(
|
||||
r'\|\|(.+?)\|\|',
|
||||
lambda m: _ph(f'||{_escape_mdv2(m.group(1))}||'),
|
||||
text,
|
||||
)
|
||||
|
||||
# 9) Convert blockquotes: > at line start → protect > from escaping
|
||||
text = re.sub(
|
||||
r'^(>{1,3}) (.+)$',
|
||||
lambda m: _ph(m.group(1) + ' ' + _escape_mdv2(m.group(2))),
|
||||
text,
|
||||
flags=re.MULTILINE,
|
||||
)
|
||||
|
||||
# 10) Escape remaining special characters in plain text
|
||||
text = _escape_mdv2(text)
|
||||
|
||||
# 8) Restore placeholders in reverse insertion order so that
|
||||
# 11) Restore placeholders in reverse insertion order so that
|
||||
# nested references (a placeholder inside another) resolve correctly.
|
||||
for key in reversed(list(placeholders.keys())):
|
||||
text = text.replace(key, placeholders[key])
|
||||
|
||||
# 12) Safety net: escape unescaped ( ) { } that slipped through
|
||||
# placeholder processing. Split the text into code/non-code
|
||||
# segments so we never touch content inside ``` or ` spans.
|
||||
_code_split = re.split(r'(```[\s\S]*?```|`[^`]+`)', text)
|
||||
_safe_parts = []
|
||||
for _idx, _seg in enumerate(_code_split):
|
||||
if _idx % 2 == 1:
|
||||
# Inside code span/block — leave untouched
|
||||
_safe_parts.append(_seg)
|
||||
else:
|
||||
# Outside code — escape bare ( ) { }
|
||||
def _esc_bare(m, _seg=_seg):
|
||||
s = m.start()
|
||||
ch = m.group(0)
|
||||
# Already escaped
|
||||
if s > 0 and _seg[s - 1] == '\\':
|
||||
return ch
|
||||
# ( that opens a MarkdownV2 link [text](url)
|
||||
if ch == '(' and s > 0 and _seg[s - 1] == ']':
|
||||
return ch
|
||||
# ) that closes a link URL
|
||||
if ch == ')':
|
||||
before = _seg[:s]
|
||||
if '](http' in before or '](' in before:
|
||||
# Check depth
|
||||
depth = 0
|
||||
for j in range(s - 1, max(s - 2000, -1), -1):
|
||||
if _seg[j] == '(':
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
if j > 0 and _seg[j - 1] == ']':
|
||||
return ch
|
||||
break
|
||||
elif _seg[j] == ')':
|
||||
depth += 1
|
||||
return '\\' + ch
|
||||
_safe_parts.append(re.sub(r'[(){}]', _esc_bare, _seg))
|
||||
text = ''.join(_safe_parts)
|
||||
|
||||
return text
|
||||
|
||||
async def _handle_text_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming text messages."""
|
||||
"""Handle incoming text messages.
|
||||
|
||||
Telegram clients split long messages into multiple updates. Buffer
|
||||
rapid successive text messages from the same user/chat and aggregate
|
||||
them into a single MessageEvent before dispatching.
|
||||
"""
|
||||
if not update.message or not update.message.text:
|
||||
return
|
||||
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.TEXT)
|
||||
await self.handle_message(event)
|
||||
self._enqueue_text_event(event)
|
||||
|
||||
async def _handle_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming command messages."""
|
||||
@@ -818,10 +1110,75 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
event.text = "\n".join(parts)
|
||||
await self.handle_message(event)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Text message aggregation (handles Telegram client-side splits)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _text_batch_key(self, event: MessageEvent) -> str:
|
||||
"""Session-scoped key for text message batching."""
|
||||
from gateway.session import build_session_key
|
||||
return build_session_key(
|
||||
event.source,
|
||||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||||
)
|
||||
|
||||
def _enqueue_text_event(self, event: MessageEvent) -> None:
|
||||
"""Buffer a text event and reset the flush timer.
|
||||
|
||||
When Telegram splits a long user message into multiple updates,
|
||||
they arrive within a few hundred milliseconds. This method
|
||||
concatenates them and waits for a short quiet period before
|
||||
dispatching the combined message.
|
||||
"""
|
||||
key = self._text_batch_key(event)
|
||||
existing = self._pending_text_batches.get(key)
|
||||
if existing is None:
|
||||
self._pending_text_batches[key] = event
|
||||
else:
|
||||
# Append text from the follow-up chunk
|
||||
if event.text:
|
||||
existing.text = f"{existing.text}\n{event.text}" if existing.text else event.text
|
||||
# Merge any media that might be attached
|
||||
if event.media_urls:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
|
||||
# Cancel any pending flush and restart the timer
|
||||
prior_task = self._pending_text_batch_tasks.get(key)
|
||||
if prior_task and not prior_task.done():
|
||||
prior_task.cancel()
|
||||
self._pending_text_batch_tasks[key] = asyncio.create_task(
|
||||
self._flush_text_batch(key)
|
||||
)
|
||||
|
||||
async def _flush_text_batch(self, key: str) -> None:
|
||||
"""Wait for the quiet period then dispatch the aggregated text."""
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
await asyncio.sleep(self._text_batch_delay_seconds)
|
||||
event = self._pending_text_batches.pop(key, None)
|
||||
if not event:
|
||||
return
|
||||
logger.info(
|
||||
"[Telegram] Flushing text batch %s (%d chars)",
|
||||
key, len(event.text or ""),
|
||||
)
|
||||
await self.handle_message(event)
|
||||
finally:
|
||||
if self._pending_text_batch_tasks.get(key) is current_task:
|
||||
self._pending_text_batch_tasks.pop(key, None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Photo batching
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _photo_batch_key(self, event: MessageEvent, msg: Message) -> str:
|
||||
"""Return a batching key for Telegram photos/albums."""
|
||||
from gateway.session import build_session_key
|
||||
session_key = build_session_key(event.source)
|
||||
session_key = build_session_key(
|
||||
event.source,
|
||||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||||
)
|
||||
media_group_id = getattr(msg, "media_group_id", None)
|
||||
if media_group_id:
|
||||
return f"{session_key}:album:{media_group_id}"
|
||||
@@ -1155,11 +1512,20 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
thread_id=str(message.message_thread_id) if message.message_thread_id else None,
|
||||
)
|
||||
|
||||
# Extract reply context if this message is a reply
|
||||
reply_to_id = None
|
||||
reply_to_text = None
|
||||
if message.reply_to_message:
|
||||
reply_to_id = str(message.reply_to_message.message_id)
|
||||
reply_to_text = message.reply_to_message.text or message.reply_to_message.caption or None
|
||||
|
||||
return MessageEvent(
|
||||
text=message.text or "",
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=message,
|
||||
message_id=str(message.message_id),
|
||||
reply_to_message_id=reply_to_id,
|
||||
reply_to_text=reply_to_text,
|
||||
timestamp=message.date,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,557 @@
|
||||
"""Generic webhook platform adapter.
|
||||
|
||||
Runs an aiohttp HTTP server that receives webhook POSTs from external
|
||||
services (GitHub, GitLab, JIRA, Stripe, etc.), validates HMAC signatures,
|
||||
transforms payloads into agent prompts, and routes responses back to the
|
||||
source or to another configured platform.
|
||||
|
||||
Configuration lives in config.yaml under platforms.webhook.extra.routes.
|
||||
Each route defines:
|
||||
- events: which event types to accept (header-based filtering)
|
||||
- secret: HMAC secret for signature validation (REQUIRED)
|
||||
- prompt: template string formatted with the webhook payload
|
||||
- skills: optional list of skills to load for the agent
|
||||
- deliver: where to send the response (github_comment, telegram, etc.)
|
||||
- deliver_extra: additional delivery config (repo, pr_number, chat_id)
|
||||
|
||||
Security:
|
||||
- HMAC secret is required per route (validated at startup)
|
||||
- Rate limiting per route (fixed-window, configurable)
|
||||
- Idempotency cache prevents duplicate agent runs on webhook retries
|
||||
- Body size limits checked before reading payload
|
||||
- Set secret to "INSECURE_NO_AUTH" to skip validation (testing only)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from aiohttp import web
|
||||
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
web = None # type: ignore[assignment]
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_HOST = "0.0.0.0"
|
||||
DEFAULT_PORT = 8644
|
||||
_INSECURE_NO_AUTH = "INSECURE_NO_AUTH"
|
||||
|
||||
|
||||
def check_webhook_requirements() -> bool:
|
||||
"""Check if webhook adapter dependencies are available."""
|
||||
return AIOHTTP_AVAILABLE
|
||||
|
||||
|
||||
class WebhookAdapter(BasePlatformAdapter):
|
||||
"""Generic webhook receiver that triggers agent runs from HTTP POSTs."""
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.WEBHOOK)
|
||||
self._host: str = config.extra.get("host", DEFAULT_HOST)
|
||||
self._port: int = int(config.extra.get("port", DEFAULT_PORT))
|
||||
self._global_secret: str = config.extra.get("secret", "")
|
||||
self._routes: Dict[str, dict] = config.extra.get("routes", {})
|
||||
self._runner = None
|
||||
|
||||
# Delivery info keyed by session chat_id — consumed by send()
|
||||
self._delivery_info: Dict[str, dict] = {}
|
||||
|
||||
# Reference to gateway runner for cross-platform delivery (set externally)
|
||||
self.gateway_runner = None
|
||||
|
||||
# Idempotency: TTL cache of recently processed delivery IDs.
|
||||
# Prevents duplicate agent runs when webhook providers retry.
|
||||
self._seen_deliveries: Dict[str, float] = {}
|
||||
self._idempotency_ttl: int = 3600 # 1 hour
|
||||
|
||||
# Rate limiting: per-route timestamps in a fixed window.
|
||||
self._rate_counts: Dict[str, List[float]] = {}
|
||||
self._rate_limit: int = int(config.extra.get("rate_limit", 30)) # per minute
|
||||
|
||||
# Body size limit (auth-before-body pattern)
|
||||
self._max_body_bytes: int = int(
|
||||
config.extra.get("max_body_bytes", 1_048_576)
|
||||
) # 1MB
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
# Validate routes at startup — secret is required per route
|
||||
for name, route in self._routes.items():
|
||||
secret = route.get("secret", self._global_secret)
|
||||
if not secret:
|
||||
raise ValueError(
|
||||
f"[webhook] Route '{name}' has no HMAC secret. "
|
||||
f"Set 'secret' on the route or globally. "
|
||||
f"For testing without auth, set secret to '{_INSECURE_NO_AUTH}'."
|
||||
)
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/health", self._handle_health)
|
||||
app.router.add_post("/webhooks/{route_name}", self._handle_webhook)
|
||||
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, self._host, self._port)
|
||||
await site.start()
|
||||
self._mark_connected()
|
||||
|
||||
route_names = ", ".join(self._routes.keys()) or "(none configured)"
|
||||
logger.info(
|
||||
"[webhook] Listening on %s:%d — routes: %s",
|
||||
self._host,
|
||||
self._port,
|
||||
route_names,
|
||||
)
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._runner:
|
||||
await self._runner.cleanup()
|
||||
self._runner = None
|
||||
self._mark_disconnected()
|
||||
logger.info("[webhook] Disconnected")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Deliver the agent's response to the configured destination.
|
||||
|
||||
chat_id is ``webhook:{route}:{delivery_id}`` — we pop the delivery
|
||||
info stored during webhook receipt so it doesn't leak memory.
|
||||
"""
|
||||
delivery = self._delivery_info.pop(chat_id, {})
|
||||
deliver_type = delivery.get("deliver", "log")
|
||||
|
||||
if deliver_type == "log":
|
||||
logger.info("[webhook] Response for %s: %s", chat_id, content[:200])
|
||||
return SendResult(success=True)
|
||||
|
||||
if deliver_type == "github_comment":
|
||||
return await self._deliver_github_comment(content, delivery)
|
||||
|
||||
# Cross-platform delivery (telegram, discord, etc.)
|
||||
if self.gateway_runner and deliver_type in (
|
||||
"telegram",
|
||||
"discord",
|
||||
"slack",
|
||||
"signal",
|
||||
"sms",
|
||||
):
|
||||
return await self._deliver_cross_platform(
|
||||
deliver_type, content, delivery
|
||||
)
|
||||
|
||||
logger.warning("[webhook] Unknown deliver type: %s", deliver_type)
|
||||
return SendResult(
|
||||
success=False, error=f"Unknown deliver type: {deliver_type}"
|
||||
)
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
return {"name": chat_id, "type": "webhook"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_health(self, request: "web.Request") -> "web.Response":
|
||||
"""GET /health — simple health check."""
|
||||
return web.json_response({"status": "ok", "platform": "webhook"})
|
||||
|
||||
async def _handle_webhook(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /webhooks/{route_name} — receive and process a webhook event."""
|
||||
route_name = request.match_info.get("route_name", "")
|
||||
route_config = self._routes.get(route_name)
|
||||
|
||||
if not route_config:
|
||||
return web.json_response(
|
||||
{"error": f"Unknown route: {route_name}"}, status=404
|
||||
)
|
||||
|
||||
# ── Auth-before-body ─────────────────────────────────────
|
||||
# Check Content-Length before reading the full payload.
|
||||
content_length = request.content_length or 0
|
||||
if content_length > self._max_body_bytes:
|
||||
return web.json_response(
|
||||
{"error": "Payload too large"}, status=413
|
||||
)
|
||||
|
||||
# ── Rate limiting ────────────────────────────────────────
|
||||
now = time.time()
|
||||
window = self._rate_counts.setdefault(route_name, [])
|
||||
window[:] = [t for t in window if now - t < 60]
|
||||
if len(window) >= self._rate_limit:
|
||||
return web.json_response(
|
||||
{"error": "Rate limit exceeded"}, status=429
|
||||
)
|
||||
window.append(now)
|
||||
|
||||
# Read body
|
||||
try:
|
||||
raw_body = await request.read()
|
||||
except Exception as e:
|
||||
logger.error("[webhook] Failed to read body: %s", e)
|
||||
return web.json_response({"error": "Bad request"}, status=400)
|
||||
|
||||
# Validate HMAC signature (skip for INSECURE_NO_AUTH testing mode)
|
||||
secret = route_config.get("secret", self._global_secret)
|
||||
if secret and secret != _INSECURE_NO_AUTH:
|
||||
if not self._validate_signature(request, raw_body, secret):
|
||||
logger.warning(
|
||||
"[webhook] Invalid signature for route %s", route_name
|
||||
)
|
||||
return web.json_response(
|
||||
{"error": "Invalid signature"}, status=401
|
||||
)
|
||||
|
||||
# Parse payload
|
||||
try:
|
||||
payload = json.loads(raw_body)
|
||||
except json.JSONDecodeError:
|
||||
# Try form-encoded as fallback
|
||||
try:
|
||||
import urllib.parse
|
||||
|
||||
payload = dict(
|
||||
urllib.parse.parse_qsl(raw_body.decode("utf-8"))
|
||||
)
|
||||
except Exception:
|
||||
return web.json_response(
|
||||
{"error": "Cannot parse body"}, status=400
|
||||
)
|
||||
|
||||
# Check event type filter
|
||||
event_type = (
|
||||
request.headers.get("X-GitHub-Event", "")
|
||||
or request.headers.get("X-GitLab-Event", "")
|
||||
or payload.get("event_type", "")
|
||||
or "unknown"
|
||||
)
|
||||
allowed_events = route_config.get("events", [])
|
||||
if allowed_events and event_type not in allowed_events:
|
||||
logger.debug(
|
||||
"[webhook] Ignoring event %s for route %s (allowed: %s)",
|
||||
event_type,
|
||||
route_name,
|
||||
allowed_events,
|
||||
)
|
||||
return web.json_response(
|
||||
{"status": "ignored", "event": event_type}
|
||||
)
|
||||
|
||||
# Format prompt from template
|
||||
prompt_template = route_config.get("prompt", "")
|
||||
prompt = self._render_prompt(
|
||||
prompt_template, payload, event_type, route_name
|
||||
)
|
||||
|
||||
# Inject skill content if configured.
|
||||
# We call build_skill_invocation_message() directly rather than
|
||||
# using /skill-name slash commands — the gateway's command parser
|
||||
# would intercept those and break the flow.
|
||||
skills = route_config.get("skills", [])
|
||||
if skills:
|
||||
try:
|
||||
from agent.skill_commands import (
|
||||
build_skill_invocation_message,
|
||||
get_skill_commands,
|
||||
)
|
||||
|
||||
skill_cmds = get_skill_commands()
|
||||
for skill_name in skills:
|
||||
cmd_key = f"/{skill_name}"
|
||||
if cmd_key in skill_cmds:
|
||||
skill_content = build_skill_invocation_message(
|
||||
cmd_key, user_instruction=prompt
|
||||
)
|
||||
if skill_content:
|
||||
prompt = skill_content
|
||||
break # Load the first matching skill
|
||||
else:
|
||||
logger.warning(
|
||||
"[webhook] Skill '%s' not found", skill_name
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[webhook] Skill loading failed: %s", e)
|
||||
|
||||
# Build a unique delivery ID
|
||||
delivery_id = request.headers.get(
|
||||
"X-GitHub-Delivery",
|
||||
request.headers.get("X-Request-ID", str(int(time.time() * 1000))),
|
||||
)
|
||||
|
||||
# ── Idempotency ─────────────────────────────────────────
|
||||
# Skip duplicate deliveries (webhook retries).
|
||||
now = time.time()
|
||||
# Prune expired entries
|
||||
self._seen_deliveries = {
|
||||
k: v
|
||||
for k, v in self._seen_deliveries.items()
|
||||
if now - v < self._idempotency_ttl
|
||||
}
|
||||
if delivery_id in self._seen_deliveries:
|
||||
logger.info(
|
||||
"[webhook] Skipping duplicate delivery %s", delivery_id
|
||||
)
|
||||
return web.json_response(
|
||||
{"status": "duplicate", "delivery_id": delivery_id},
|
||||
status=200,
|
||||
)
|
||||
self._seen_deliveries[delivery_id] = now
|
||||
|
||||
# Use delivery_id in session key so concurrent webhooks on the
|
||||
# same route get independent agent runs (not queued/interrupted).
|
||||
session_chat_id = f"webhook:{route_name}:{delivery_id}"
|
||||
|
||||
# Store delivery info for send() — consumed (popped) on delivery
|
||||
deliver_config = {
|
||||
"deliver": route_config.get("deliver", "log"),
|
||||
"deliver_extra": self._render_delivery_extra(
|
||||
route_config.get("deliver_extra", {}), payload
|
||||
),
|
||||
"payload": payload,
|
||||
}
|
||||
self._delivery_info[session_chat_id] = deliver_config
|
||||
|
||||
# Build source and event
|
||||
source = self.build_source(
|
||||
chat_id=session_chat_id,
|
||||
chat_name=f"webhook/{route_name}",
|
||||
chat_type="webhook",
|
||||
user_id=f"webhook:{route_name}",
|
||||
user_name=route_name,
|
||||
)
|
||||
event = MessageEvent(
|
||||
text=prompt,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message=payload,
|
||||
message_id=delivery_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[webhook] %s event=%s route=%s prompt_len=%d delivery=%s",
|
||||
request.method,
|
||||
event_type,
|
||||
route_name,
|
||||
len(prompt),
|
||||
delivery_id,
|
||||
)
|
||||
|
||||
# Non-blocking — return 202 Accepted immediately
|
||||
asyncio.create_task(self.handle_message(event))
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"status": "accepted",
|
||||
"route": route_name,
|
||||
"event": event_type,
|
||||
"delivery_id": delivery_id,
|
||||
},
|
||||
status=202,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Signature validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_signature(
|
||||
self, request: "web.Request", body: bytes, secret: str
|
||||
) -> bool:
|
||||
"""Validate webhook signature (GitHub, GitLab, generic HMAC-SHA256)."""
|
||||
# GitHub: X-Hub-Signature-256 = sha256=<hex>
|
||||
gh_sig = request.headers.get("X-Hub-Signature-256", "")
|
||||
if gh_sig:
|
||||
expected = "sha256=" + hmac.new(
|
||||
secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
return hmac.compare_digest(gh_sig, expected)
|
||||
|
||||
# GitLab: X-Gitlab-Token = <plain secret>
|
||||
gl_token = request.headers.get("X-Gitlab-Token", "")
|
||||
if gl_token:
|
||||
return hmac.compare_digest(gl_token, secret)
|
||||
|
||||
# Generic: X-Webhook-Signature = <hex HMAC-SHA256>
|
||||
generic_sig = request.headers.get("X-Webhook-Signature", "")
|
||||
if generic_sig:
|
||||
expected = hmac.new(
|
||||
secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
return hmac.compare_digest(generic_sig, expected)
|
||||
|
||||
# No recognised signature header but secret is configured → reject
|
||||
logger.debug(
|
||||
"[webhook] Secret configured but no signature header found"
|
||||
)
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt rendering
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _render_prompt(
|
||||
self,
|
||||
template: str,
|
||||
payload: dict,
|
||||
event_type: str,
|
||||
route_name: str,
|
||||
) -> str:
|
||||
"""Render a prompt template with the webhook payload.
|
||||
|
||||
Supports dot-notation access into nested dicts:
|
||||
``{pull_request.title}`` → ``payload["pull_request"]["title"]``
|
||||
"""
|
||||
if not template:
|
||||
truncated = json.dumps(payload, indent=2)[:4000]
|
||||
return (
|
||||
f"Webhook event '{event_type}' on route "
|
||||
f"'{route_name}':\n\n```json\n{truncated}\n```"
|
||||
)
|
||||
|
||||
def _resolve(match: re.Match) -> str:
|
||||
key = match.group(1)
|
||||
value: Any = payload
|
||||
for part in key.split("."):
|
||||
if isinstance(value, dict):
|
||||
value = value.get(part, f"{{{key}}}")
|
||||
else:
|
||||
return f"{{{key}}}"
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, indent=2)[:2000]
|
||||
return str(value)
|
||||
|
||||
return re.sub(r"\{([a-zA-Z0-9_.]+)\}", _resolve, template)
|
||||
|
||||
def _render_delivery_extra(
|
||||
self, extra: dict, payload: dict
|
||||
) -> dict:
|
||||
"""Render delivery_extra template values with payload data."""
|
||||
rendered: Dict[str, Any] = {}
|
||||
for key, value in extra.items():
|
||||
if isinstance(value, str):
|
||||
rendered[key] = self._render_prompt(value, payload, "", "")
|
||||
else:
|
||||
rendered[key] = value
|
||||
return rendered
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Response delivery
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _deliver_github_comment(
|
||||
self, content: str, delivery: dict
|
||||
) -> SendResult:
|
||||
"""Post agent response as a GitHub PR/issue comment via ``gh`` CLI."""
|
||||
extra = delivery.get("deliver_extra", {})
|
||||
repo = extra.get("repo", "")
|
||||
pr_number = extra.get("pr_number", "")
|
||||
|
||||
if not repo or not pr_number:
|
||||
logger.error(
|
||||
"[webhook] github_comment delivery missing repo or pr_number"
|
||||
)
|
||||
return SendResult(
|
||||
success=False, error="Missing repo or pr_number"
|
||||
)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"gh",
|
||||
"pr",
|
||||
"comment",
|
||||
str(pr_number),
|
||||
"--repo",
|
||||
repo,
|
||||
"--body",
|
||||
content,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info(
|
||||
"[webhook] Posted comment on %s#%s", repo, pr_number
|
||||
)
|
||||
return SendResult(success=True)
|
||||
else:
|
||||
logger.error(
|
||||
"[webhook] gh pr comment failed: %s", result.stderr
|
||||
)
|
||||
return SendResult(success=False, error=result.stderr)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"[webhook] 'gh' CLI not found — install GitHub CLI for "
|
||||
"github_comment delivery"
|
||||
)
|
||||
return SendResult(
|
||||
success=False, error="gh CLI not installed"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("[webhook] github_comment delivery error: %s", e)
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def _deliver_cross_platform(
|
||||
self, platform_name: str, content: str, delivery: dict
|
||||
) -> SendResult:
|
||||
"""Route response to another platform (telegram, discord, etc.)."""
|
||||
if not self.gateway_runner:
|
||||
return SendResult(
|
||||
success=False,
|
||||
error="No gateway runner for cross-platform delivery",
|
||||
)
|
||||
|
||||
try:
|
||||
target_platform = Platform(platform_name)
|
||||
except ValueError:
|
||||
return SendResult(
|
||||
success=False, error=f"Unknown platform: {platform_name}"
|
||||
)
|
||||
|
||||
adapter = self.gateway_runner.adapters.get(target_platform)
|
||||
if not adapter:
|
||||
return SendResult(
|
||||
success=False,
|
||||
error=f"Platform {platform_name} not connected",
|
||||
)
|
||||
|
||||
# Use home channel if no specific chat_id in deliver_extra
|
||||
extra = delivery.get("deliver_extra", {})
|
||||
chat_id = extra.get("chat_id", "")
|
||||
if not chat_id:
|
||||
home = self.gateway_runner.config.get_home_channel(target_platform)
|
||||
if home:
|
||||
chat_id = home.chat_id
|
||||
else:
|
||||
return SendResult(
|
||||
success=False,
|
||||
error=f"No chat_id or home channel for {platform_name}",
|
||||
)
|
||||
|
||||
return await adapter.send(chat_id, content)
|
||||
@@ -136,6 +136,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"session_path",
|
||||
get_hermes_home() / "whatsapp" / "session"
|
||||
))
|
||||
self._reply_prefix: Optional[str] = config.extra.get("reply_prefix")
|
||||
self._message_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._bridge_log_fh = None
|
||||
self._bridge_log: Optional[Path] = None
|
||||
@@ -181,9 +182,31 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
# Ensure session directory exists
|
||||
self._session_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Check if bridge is already running and connected
|
||||
import aiohttp
|
||||
import asyncio
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://127.0.0.1:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
bridge_status = data.get("status", "unknown")
|
||||
if bridge_status == "connected":
|
||||
print(f"[{self.name}] Using existing bridge (status: {bridge_status})")
|
||||
self._mark_connected()
|
||||
self._bridge_process = None # Not managed by us
|
||||
asyncio.create_task(self._poll_messages())
|
||||
return True
|
||||
else:
|
||||
print(f"[{self.name}] Bridge found but not connected (status: {bridge_status}), restarting")
|
||||
except Exception:
|
||||
pass # Bridge not running, start a new one
|
||||
|
||||
# Kill any orphaned bridge from a previous gateway run
|
||||
_kill_port_process(self._bridge_port)
|
||||
import asyncio
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Start the bridge process in its own process group.
|
||||
@@ -193,6 +216,14 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._bridge_log = self._session_path.parent / "bridge.log"
|
||||
bridge_log_fh = open(self._bridge_log, "a")
|
||||
self._bridge_log_fh = bridge_log_fh
|
||||
|
||||
# Build bridge subprocess environment.
|
||||
# Pass WHATSAPP_REPLY_PREFIX from config.yaml so the Node bridge
|
||||
# can use it without the user needing to set a separate env var.
|
||||
bridge_env = os.environ.copy()
|
||||
if self._reply_prefix is not None:
|
||||
bridge_env["WHATSAPP_REPLY_PREFIX"] = self._reply_prefix
|
||||
|
||||
self._bridge_process = subprocess.Popen(
|
||||
[
|
||||
"node",
|
||||
@@ -204,6 +235,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
stdout=bridge_log_fh,
|
||||
stderr=bridge_log_fh,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
env=bridge_env,
|
||||
)
|
||||
|
||||
# Wait for the bridge to connect to WhatsApp.
|
||||
@@ -222,7 +254,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/health",
|
||||
f"http://127.0.0.1:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
@@ -254,7 +286,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/health",
|
||||
f"http://127.0.0.1:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
@@ -274,7 +306,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
# Start message polling task
|
||||
asyncio.create_task(self._poll_messages())
|
||||
|
||||
self._running = True
|
||||
self._mark_connected()
|
||||
print(f"[{self.name}] Bridge started on port {self._bridge_port}")
|
||||
return True
|
||||
|
||||
@@ -292,6 +324,23 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
pass
|
||||
self._bridge_log_fh = None
|
||||
|
||||
async def _check_managed_bridge_exit(self) -> Optional[str]:
|
||||
"""Return a fatal error message if the managed bridge child exited."""
|
||||
if self._bridge_process is None:
|
||||
return None
|
||||
|
||||
returncode = self._bridge_process.poll()
|
||||
if returncode is None:
|
||||
return None
|
||||
|
||||
message = f"WhatsApp bridge process exited unexpectedly (code {returncode})."
|
||||
if not self.has_fatal_error:
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("whatsapp_bridge_exited", message, retryable=True)
|
||||
self._close_bridge_log()
|
||||
await self._notify_fatal_error()
|
||||
return self.fatal_error_message or message
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop the WhatsApp bridge and clean up any orphaned processes."""
|
||||
if self._bridge_process:
|
||||
@@ -316,11 +365,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._bridge_process.kill()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error stopping bridge: {e}")
|
||||
else:
|
||||
# Bridge was not started by us, don't kill it
|
||||
print(f"[{self.name}] Disconnecting (external bridge left running)")
|
||||
|
||||
# Also kill any orphaned bridge processes on our port
|
||||
_kill_port_process(self._bridge_port)
|
||||
|
||||
self._running = False
|
||||
self._mark_disconnected()
|
||||
self._bridge_process = None
|
||||
self._close_bridge_log()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
@@ -335,6 +384,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Send a message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
return SendResult(success=False, error=bridge_exit)
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -348,7 +400,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
payload["replyTo"] = reply_to
|
||||
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/send",
|
||||
f"http://127.0.0.1:{self._bridge_port}/send",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
@@ -380,11 +432,14 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Edit a previously sent message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
return SendResult(success=False, error=bridge_exit)
|
||||
try:
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/edit",
|
||||
f"http://127.0.0.1:{self._bridge_port}/edit",
|
||||
json={
|
||||
"chatId": chat_id,
|
||||
"messageId": message_id,
|
||||
@@ -411,6 +466,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Send any media file via bridge /send-media endpoint."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
return SendResult(success=False, error=bridge_exit)
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
@@ -429,7 +487,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/send-media",
|
||||
f"http://127.0.0.1:{self._bridge_port}/send-media",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=120),
|
||||
) as resp:
|
||||
@@ -499,13 +557,15 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Send typing indicator via bridge."""
|
||||
if not self._running:
|
||||
return
|
||||
if await self._check_managed_bridge_exit():
|
||||
return
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await session.post(
|
||||
f"http://localhost:{self._bridge_port}/typing",
|
||||
f"http://127.0.0.1:{self._bridge_port}/typing",
|
||||
json={"chatId": chat_id},
|
||||
timeout=aiohttp.ClientTimeout(total=5)
|
||||
)
|
||||
@@ -516,13 +576,15 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Get information about a WhatsApp chat."""
|
||||
if not self._running:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
if await self._check_managed_bridge_exit():
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/chat/{chat_id}",
|
||||
f"http://127.0.0.1:{self._bridge_port}/chat/{chat_id}",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
@@ -546,10 +608,14 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
return
|
||||
|
||||
while self._running:
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
print(f"[{self.name}] {bridge_exit}")
|
||||
break
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/messages",
|
||||
f"http://127.0.0.1:{self._bridge_port}/messages",
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
@@ -561,6 +627,10 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
print(f"[{self.name}] {bridge_exit}")
|
||||
break
|
||||
print(f"[{self.name}] Poll error: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
@@ -611,6 +681,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] Failed to cache image: {e}", flush=True)
|
||||
cached_urls.append(url)
|
||||
media_types.append("image/jpeg")
|
||||
elif msg_type == MessageType.PHOTO and os.path.isabs(url):
|
||||
# Local file path — bridge already downloaded the image
|
||||
cached_urls.append(url)
|
||||
media_types.append("image/jpeg")
|
||||
print(f"[{self.name}] Using bridge-cached image: {url}", flush=True)
|
||||
elif msg_type == MessageType.VOICE and url.startswith(("http://", "https://")):
|
||||
try:
|
||||
cached_path = await cache_audio_from_url(url, ext=".ogg")
|
||||
@@ -637,4 +712,3 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error building event: {e}")
|
||||
return None
|
||||
|
||||
|
||||
+1515
-288
File diff suppressed because it is too large
Load Diff
+176
-22
@@ -8,9 +8,11 @@ Handles:
|
||||
- Dynamic system prompt injection (agent knows its context)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
@@ -19,6 +21,41 @@ from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PII redaction helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PHONE_RE = re.compile(r"^\+?\d[\d\-\s]{6,}$")
|
||||
|
||||
|
||||
def _hash_id(value: str) -> str:
|
||||
"""Deterministic 12-char hex hash of an identifier."""
|
||||
return hashlib.sha256(value.encode("utf-8")).hexdigest()[:12]
|
||||
|
||||
|
||||
def _hash_sender_id(value: str) -> str:
|
||||
"""Hash a sender ID to ``user_<12hex>``."""
|
||||
return f"user_{_hash_id(value)}"
|
||||
|
||||
|
||||
def _hash_chat_id(value: str) -> str:
|
||||
"""Hash the numeric portion of a chat ID, preserving platform prefix.
|
||||
|
||||
``telegram:12345`` → ``telegram:<hash>``
|
||||
``12345`` → ``<hash>``
|
||||
"""
|
||||
colon = value.find(":")
|
||||
if colon > 0:
|
||||
prefix = value[:colon]
|
||||
return f"{prefix}:{_hash_id(value[colon + 1:])}"
|
||||
return _hash_id(value)
|
||||
|
||||
|
||||
def _looks_like_phone(value: str) -> bool:
|
||||
"""Return True if *value* looks like a phone number (E.164 or similar)."""
|
||||
return bool(_PHONE_RE.match(value.strip()))
|
||||
|
||||
from .config import (
|
||||
Platform,
|
||||
GatewayConfig,
|
||||
@@ -146,7 +183,21 @@ class SessionContext:
|
||||
}
|
||||
|
||||
|
||||
def build_session_context_prompt(context: SessionContext) -> str:
|
||||
_PII_SAFE_PLATFORMS = frozenset({
|
||||
Platform.WHATSAPP,
|
||||
Platform.SIGNAL,
|
||||
Platform.TELEGRAM,
|
||||
})
|
||||
"""Platforms where user IDs can be safely redacted (no in-message mention system
|
||||
that requires raw IDs). Discord is excluded because mentions use ``<@user_id>``
|
||||
and the LLM needs the real ID to tag users."""
|
||||
|
||||
|
||||
def build_session_context_prompt(
|
||||
context: SessionContext,
|
||||
*,
|
||||
redact_pii: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Build the dynamic system prompt section that tells the agent about its context.
|
||||
|
||||
@@ -154,7 +205,15 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
||||
- Where messages are coming from
|
||||
- What platforms are connected
|
||||
- Where it can deliver scheduled task outputs
|
||||
|
||||
When *redact_pii* is True **and** the source platform is in
|
||||
``_PII_SAFE_PLATFORMS``, phone numbers are stripped and user/chat IDs
|
||||
are replaced with deterministic hashes before being sent to the LLM.
|
||||
Platforms like Discord are excluded because mentions need real IDs.
|
||||
Routing still uses the original values (they stay in SessionSource).
|
||||
"""
|
||||
# Only apply redaction on platforms where IDs aren't needed for mentions
|
||||
redact_pii = redact_pii and context.source.platform in _PII_SAFE_PLATFORMS
|
||||
lines = [
|
||||
"## Current Session Context",
|
||||
"",
|
||||
@@ -165,7 +224,25 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append(f"**Source:** {platform_name} (the machine running this agent)")
|
||||
else:
|
||||
lines.append(f"**Source:** {platform_name} ({context.source.description})")
|
||||
# Build a description that respects PII redaction
|
||||
src = context.source
|
||||
if redact_pii:
|
||||
# Build a safe description without raw IDs
|
||||
_uname = src.user_name or (
|
||||
_hash_sender_id(src.user_id) if src.user_id else "user"
|
||||
)
|
||||
_cname = src.chat_name or _hash_chat_id(src.chat_id)
|
||||
if src.chat_type == "dm":
|
||||
desc = f"DM with {_uname}"
|
||||
elif src.chat_type == "group":
|
||||
desc = f"group: {_cname}"
|
||||
elif src.chat_type == "channel":
|
||||
desc = f"channel: {_cname}"
|
||||
else:
|
||||
desc = _cname
|
||||
else:
|
||||
desc = src.description
|
||||
lines.append(f"**Source:** {platform_name} ({desc})")
|
||||
|
||||
# Channel topic (if available - provides context about the channel's purpose)
|
||||
if context.source.chat_topic:
|
||||
@@ -175,7 +252,10 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
||||
if context.source.user_name:
|
||||
lines.append(f"**User:** {context.source.user_name}")
|
||||
elif context.source.user_id:
|
||||
lines.append(f"**User ID:** {context.source.user_id}")
|
||||
uid = context.source.user_id
|
||||
if redact_pii:
|
||||
uid = _hash_sender_id(uid)
|
||||
lines.append(f"**User ID:** {uid}")
|
||||
|
||||
# Platform-specific behavioral notes
|
||||
if context.source.platform == Platform.SLACK:
|
||||
@@ -210,7 +290,8 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
||||
lines.append("")
|
||||
lines.append("**Home Channels (default destinations):**")
|
||||
for platform, home in context.home_channels.items():
|
||||
lines.append(f" - {platform.value}: {home.name} (ID: {home.chat_id})")
|
||||
hc_id = _hash_chat_id(home.chat_id) if redact_pii else home.chat_id
|
||||
lines.append(f" - {platform.value}: {home.name} (ID: {hc_id})")
|
||||
|
||||
# Delivery options for scheduled tasks
|
||||
lines.append("")
|
||||
@@ -220,7 +301,10 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append("- `\"origin\"` → Local output (saved to files)")
|
||||
else:
|
||||
lines.append(f"- `\"origin\"` → Back to this chat ({context.source.chat_name or context.source.chat_id})")
|
||||
_origin_label = context.source.chat_name or (
|
||||
_hash_chat_id(context.source.chat_id) if redact_pii else context.source.chat_id
|
||||
)
|
||||
lines.append(f"- `\"origin\"` → Back to this chat ({_origin_label})")
|
||||
|
||||
# Local always available
|
||||
lines.append("- `\"local\"` → Save to local files only (~/.hermes/cron/output/)")
|
||||
@@ -259,7 +343,11 @@ class SessionEntry:
|
||||
# Token tracking
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
cache_write_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
estimated_cost_usd: float = 0.0
|
||||
cost_status: str = "unknown"
|
||||
|
||||
# Last API-reported prompt tokens (for accurate compression pre-check)
|
||||
last_prompt_tokens: int = 0
|
||||
@@ -267,6 +355,8 @@ class SessionEntry:
|
||||
# Set when a session was created because the previous one expired;
|
||||
# consumed once by the message handler to inject a notice into context
|
||||
was_auto_reset: bool = False
|
||||
auto_reset_reason: Optional[str] = None # "idle" or "daily"
|
||||
reset_had_activity: bool = False # whether the expired session had any messages
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
@@ -279,8 +369,12 @@ class SessionEntry:
|
||||
"chat_type": self.chat_type,
|
||||
"input_tokens": self.input_tokens,
|
||||
"output_tokens": self.output_tokens,
|
||||
"cache_read_tokens": self.cache_read_tokens,
|
||||
"cache_write_tokens": self.cache_write_tokens,
|
||||
"total_tokens": self.total_tokens,
|
||||
"last_prompt_tokens": self.last_prompt_tokens,
|
||||
"estimated_cost_usd": self.estimated_cost_usd,
|
||||
"cost_status": self.cost_status,
|
||||
}
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
@@ -310,12 +404,16 @@ class SessionEntry:
|
||||
chat_type=data.get("chat_type", "dm"),
|
||||
input_tokens=data.get("input_tokens", 0),
|
||||
output_tokens=data.get("output_tokens", 0),
|
||||
cache_read_tokens=data.get("cache_read_tokens", 0),
|
||||
cache_write_tokens=data.get("cache_write_tokens", 0),
|
||||
total_tokens=data.get("total_tokens", 0),
|
||||
last_prompt_tokens=data.get("last_prompt_tokens", 0),
|
||||
estimated_cost_usd=data.get("estimated_cost_usd", 0.0),
|
||||
cost_status=data.get("cost_status", "unknown"),
|
||||
)
|
||||
|
||||
|
||||
def build_session_key(source: SessionSource) -> str:
|
||||
def build_session_key(source: SessionSource, group_sessions_per_user: bool = True) -> str:
|
||||
"""Build a deterministic session key from a message source.
|
||||
|
||||
This is the single source of truth for session key construction.
|
||||
@@ -328,7 +426,11 @@ def build_session_key(source: SessionSource) -> str:
|
||||
|
||||
Group/channel rules:
|
||||
- chat_id identifies the parent group/channel.
|
||||
- user_id/user_id_alt isolates participants within that parent chat when available when
|
||||
``group_sessions_per_user`` is enabled.
|
||||
- thread_id differentiates threads within that parent chat.
|
||||
- Without participant identifiers, or when isolation is disabled, messages fall back to one
|
||||
shared session per chat.
|
||||
- Without identifiers, messages fall back to one session per platform/chat_type.
|
||||
"""
|
||||
platform = source.platform.value
|
||||
@@ -340,13 +442,18 @@ def build_session_key(source: SessionSource) -> str:
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:dm:{source.thread_id}"
|
||||
return f"agent:main:{platform}:dm"
|
||||
|
||||
participant_id = source.user_id_alt or source.user_id
|
||||
key_parts = ["agent:main", platform, source.chat_type]
|
||||
|
||||
if source.chat_id:
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}"
|
||||
key_parts.append(source.chat_id)
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}"
|
||||
key_parts.append(source.thread_id)
|
||||
if group_sessions_per_user and participant_id:
|
||||
key_parts.append(str(participant_id))
|
||||
|
||||
return ":".join(key_parts)
|
||||
|
||||
|
||||
class SessionStore:
|
||||
@@ -425,7 +532,10 @@ class SessionStore:
|
||||
|
||||
def _generate_session_key(self, source: SessionSource) -> str:
|
||||
"""Generate a session key from a source."""
|
||||
return build_session_key(source)
|
||||
return build_session_key(
|
||||
source,
|
||||
group_sessions_per_user=getattr(self.config, "group_sessions_per_user", True),
|
||||
)
|
||||
|
||||
def _is_session_expired(self, entry: SessionEntry) -> bool:
|
||||
"""Check if a session has expired based on its reset policy.
|
||||
@@ -465,16 +575,19 @@ class SessionStore:
|
||||
|
||||
return False
|
||||
|
||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
|
||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> Optional[str]:
|
||||
"""
|
||||
Check if a session should be reset based on policy.
|
||||
|
||||
Returns the reset reason ("idle" or "daily") if a reset is needed,
|
||||
or None if the session is still valid.
|
||||
|
||||
Sessions with active background processes are never reset.
|
||||
"""
|
||||
if self._has_active_processes_fn:
|
||||
session_key = self._generate_session_key(source)
|
||||
if self._has_active_processes_fn(session_key):
|
||||
return False
|
||||
return None
|
||||
|
||||
policy = self.config.get_reset_policy(
|
||||
platform=source.platform,
|
||||
@@ -482,14 +595,14 @@ class SessionStore:
|
||||
)
|
||||
|
||||
if policy.mode == "none":
|
||||
return False
|
||||
return None
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
if policy.mode in ("idle", "both"):
|
||||
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
||||
if now > idle_deadline:
|
||||
return True
|
||||
return "idle"
|
||||
|
||||
if policy.mode in ("daily", "both"):
|
||||
today_reset = now.replace(
|
||||
@@ -502,9 +615,9 @@ class SessionStore:
|
||||
today_reset -= timedelta(days=1)
|
||||
|
||||
if entry.updated_at < today_reset:
|
||||
return True
|
||||
return "daily"
|
||||
|
||||
return False
|
||||
return None
|
||||
|
||||
def has_any_sessions(self) -> bool:
|
||||
"""Check if any sessions have ever been created (across all platforms).
|
||||
@@ -546,7 +659,8 @@ class SessionStore:
|
||||
if session_key in self._entries and not force_new:
|
||||
entry = self._entries[session_key]
|
||||
|
||||
if not self._should_reset(entry, source):
|
||||
reset_reason = self._should_reset(entry, source)
|
||||
if not reset_reason:
|
||||
entry.updated_at = now
|
||||
self._save()
|
||||
return entry
|
||||
@@ -555,6 +669,9 @@ class SessionStore:
|
||||
# should have already flushed memories proactively; discard
|
||||
# the marker so it doesn't accumulate.
|
||||
was_auto_reset = True
|
||||
auto_reset_reason = reset_reason
|
||||
# Track whether the expired session had any real conversation
|
||||
reset_had_activity = entry.total_tokens > 0
|
||||
self._pre_flushed_sessions.discard(entry.session_id)
|
||||
if self._db:
|
||||
try:
|
||||
@@ -563,6 +680,8 @@ class SessionStore:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
else:
|
||||
was_auto_reset = False
|
||||
auto_reset_reason = None
|
||||
reset_had_activity = False
|
||||
|
||||
# Create new session
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
@@ -577,6 +696,8 @@ class SessionStore:
|
||||
platform=source.platform,
|
||||
chat_type=source.chat_type,
|
||||
was_auto_reset=was_auto_reset,
|
||||
auto_reset_reason=auto_reset_reason,
|
||||
reset_had_activity=reset_had_activity,
|
||||
)
|
||||
|
||||
self._entries[session_key] = entry
|
||||
@@ -600,8 +721,15 @@ class SessionStore:
|
||||
session_key: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
last_prompt_tokens: int = None,
|
||||
model: str = None,
|
||||
estimated_cost_usd: Optional[float] = None,
|
||||
cost_status: Optional[str] = None,
|
||||
cost_source: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update a session's metadata after an interaction."""
|
||||
self._ensure_loaded()
|
||||
@@ -611,15 +739,35 @@ class SessionStore:
|
||||
entry.updated_at = datetime.now()
|
||||
entry.input_tokens += input_tokens
|
||||
entry.output_tokens += output_tokens
|
||||
entry.cache_read_tokens += cache_read_tokens
|
||||
entry.cache_write_tokens += cache_write_tokens
|
||||
if last_prompt_tokens is not None:
|
||||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
entry.total_tokens = entry.input_tokens + entry.output_tokens
|
||||
if estimated_cost_usd is not None:
|
||||
entry.estimated_cost_usd += estimated_cost_usd
|
||||
if cost_status:
|
||||
entry.cost_status = cost_status
|
||||
entry.total_tokens = (
|
||||
entry.input_tokens
|
||||
+ entry.output_tokens
|
||||
+ entry.cache_read_tokens
|
||||
+ entry.cache_write_tokens
|
||||
)
|
||||
self._save()
|
||||
|
||||
if self._db:
|
||||
try:
|
||||
self._db.update_token_counts(
|
||||
entry.session_id, input_tokens, output_tokens,
|
||||
entry.session_id,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
estimated_cost_usd=estimated_cost_usd,
|
||||
cost_status=cost_status,
|
||||
cost_source=cost_source,
|
||||
billing_provider=provider,
|
||||
billing_base_url=base_url,
|
||||
model=model,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -809,7 +957,13 @@ class SessionStore:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
messages.append(json.loads(line))
|
||||
try:
|
||||
messages.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Skipping corrupt line in transcript %s: %s",
|
||||
session_id, line[:120],
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
+60
-6
@@ -83,11 +83,30 @@ def _looks_like_gateway_process(pid: int) -> bool:
|
||||
"""Return True when the live PID still looks like the Hermes gateway."""
|
||||
cmdline = _read_process_cmdline(pid)
|
||||
if not cmdline:
|
||||
# If we cannot inspect the process, fall back to the liveness check.
|
||||
return True
|
||||
return False
|
||||
|
||||
patterns = (
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli/main.py gateway",
|
||||
"hermes gateway",
|
||||
"gateway/run.py",
|
||||
)
|
||||
return any(pattern in cmdline for pattern in patterns)
|
||||
|
||||
|
||||
def _record_looks_like_gateway(record: dict[str, Any]) -> bool:
|
||||
"""Validate gateway identity from PID-file metadata when cmdline is unavailable."""
|
||||
if record.get("kind") != _GATEWAY_KIND:
|
||||
return False
|
||||
|
||||
argv = record.get("argv")
|
||||
if not isinstance(argv, list) or not argv:
|
||||
return False
|
||||
|
||||
cmdline = " ".join(str(part) for part in argv)
|
||||
patterns = (
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli/main.py gateway",
|
||||
"hermes gateway",
|
||||
"gateway/run.py",
|
||||
)
|
||||
@@ -178,8 +197,8 @@ def write_runtime_status(
|
||||
payload = _read_json_file(path) or _build_runtime_status_record()
|
||||
payload.setdefault("platforms", {})
|
||||
payload.setdefault("kind", _GATEWAY_KIND)
|
||||
payload.setdefault("pid", os.getpid())
|
||||
payload.setdefault("start_time", _get_process_start_time(os.getpid()))
|
||||
payload["pid"] = os.getpid()
|
||||
payload["start_time"] = _get_process_start_time(os.getpid())
|
||||
payload["updated_at"] = _utc_now_iso()
|
||||
|
||||
if gateway_state is not None:
|
||||
@@ -255,6 +274,21 @@ def acquire_scoped_lock(scope: str, identity: str, metadata: Optional[dict[str,
|
||||
and current_start != existing.get("start_time")
|
||||
):
|
||||
stale = True
|
||||
# Check if process is stopped (Ctrl+Z / SIGTSTP) — stopped
|
||||
# processes still respond to os.kill(pid, 0) but are not
|
||||
# actually running. Treat them as stale so --replace works.
|
||||
if not stale:
|
||||
try:
|
||||
_proc_status = Path(f"/proc/{existing_pid}/status")
|
||||
if _proc_status.exists():
|
||||
for _line in _proc_status.read_text().splitlines():
|
||||
if _line.startswith("State:"):
|
||||
_state = _line.split()[1]
|
||||
if _state in ("T", "t"): # stopped or tracing stop
|
||||
stale = True
|
||||
break
|
||||
except (OSError, PermissionError):
|
||||
pass
|
||||
if stale:
|
||||
try:
|
||||
lock_path.unlink(missing_ok=True)
|
||||
@@ -295,6 +329,25 @@ def release_scoped_lock(scope: str, identity: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def release_all_scoped_locks() -> int:
|
||||
"""Remove all scoped lock files in the lock directory.
|
||||
|
||||
Called during --replace to clean up stale locks left by stopped/killed
|
||||
gateway processes that did not release their locks gracefully.
|
||||
Returns the number of lock files removed.
|
||||
"""
|
||||
lock_dir = _get_lock_dir()
|
||||
removed = 0
|
||||
if lock_dir.exists():
|
||||
for lock_file in lock_dir.glob("*.lock"):
|
||||
try:
|
||||
lock_file.unlink(missing_ok=True)
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
return removed
|
||||
|
||||
|
||||
def get_running_pid() -> Optional[int]:
|
||||
"""Return the PID of a running gateway instance, or ``None``.
|
||||
|
||||
@@ -325,8 +378,9 @@ def get_running_pid() -> Optional[int]:
|
||||
return None
|
||||
|
||||
if not _looks_like_gateway_process(pid):
|
||||
remove_pid_file()
|
||||
return None
|
||||
if not _record_looks_like_gateway(record):
|
||||
remove_pid_file()
|
||||
return None
|
||||
|
||||
return pid
|
||||
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
"""Gateway streaming consumer — bridges sync agent callbacks to async platform delivery.
|
||||
|
||||
The agent fires stream_delta_callback(text) synchronously from its worker thread.
|
||||
GatewayStreamConsumer:
|
||||
1. Receives deltas via on_delta() (thread-safe, sync)
|
||||
2. Queues them to an asyncio task via queue.Queue
|
||||
3. The async run() task buffers, rate-limits, and progressively edits
|
||||
a single message on the target platform
|
||||
|
||||
Design: Uses the edit transport (send initial message, then editMessageText).
|
||||
This is universally supported across Telegram, Discord, and Slack.
|
||||
|
||||
Credit: jobless0x (#774, #1312), OutThisLife (#798), clicksingh (#697).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger("gateway.stream_consumer")
|
||||
|
||||
# Sentinel to signal the stream is complete
|
||||
_DONE = object()
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamConsumerConfig:
|
||||
"""Runtime config for a single stream consumer instance."""
|
||||
edit_interval: float = 0.3
|
||||
buffer_threshold: int = 40
|
||||
cursor: str = " ▉"
|
||||
|
||||
|
||||
class GatewayStreamConsumer:
|
||||
"""Async consumer that progressively edits a platform message with streamed tokens.
|
||||
|
||||
Usage::
|
||||
|
||||
consumer = GatewayStreamConsumer(adapter, chat_id, config, metadata=metadata)
|
||||
# Pass consumer.on_delta as stream_delta_callback to AIAgent
|
||||
agent = AIAgent(..., stream_delta_callback=consumer.on_delta)
|
||||
# Start the consumer as an asyncio task
|
||||
task = asyncio.create_task(consumer.run())
|
||||
# ... run agent in thread pool ...
|
||||
consumer.finish() # signal completion
|
||||
await task # wait for final edit
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adapter: Any,
|
||||
chat_id: str,
|
||||
config: Optional[StreamConsumerConfig] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
self.adapter = adapter
|
||||
self.chat_id = chat_id
|
||||
self.cfg = config or StreamConsumerConfig()
|
||||
self.metadata = metadata
|
||||
self._queue: queue.Queue = queue.Queue()
|
||||
self._accumulated = ""
|
||||
self._message_id: Optional[str] = None
|
||||
self._already_sent = False
|
||||
self._edit_supported = True # Disabled on first edit failure (Signal/Email/HA)
|
||||
self._last_edit_time = 0.0
|
||||
self._last_sent_text = "" # Track last-sent text to skip redundant edits
|
||||
|
||||
@property
|
||||
def already_sent(self) -> bool:
|
||||
"""True if at least one message was sent/edited — signals the base
|
||||
adapter to skip re-sending the final response."""
|
||||
return self._already_sent
|
||||
|
||||
def on_delta(self, text: str) -> None:
|
||||
"""Thread-safe callback — called from the agent's worker thread."""
|
||||
if text:
|
||||
self._queue.put(text)
|
||||
|
||||
def finish(self) -> None:
|
||||
"""Signal that the stream is complete."""
|
||||
self._queue.put(_DONE)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Async task that drains the queue and edits the platform message."""
|
||||
# Platform message length limit — leave room for cursor + formatting
|
||||
_raw_limit = getattr(self.adapter, "MAX_MESSAGE_LENGTH", 4096)
|
||||
_safe_limit = max(500, _raw_limit - len(self.cfg.cursor) - 100)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Drain all available items from the queue
|
||||
got_done = False
|
||||
while True:
|
||||
try:
|
||||
item = self._queue.get_nowait()
|
||||
if item is _DONE:
|
||||
got_done = True
|
||||
break
|
||||
self._accumulated += item
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Decide whether to flush an edit
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_edit_time
|
||||
should_edit = (
|
||||
got_done
|
||||
or (elapsed >= self.cfg.edit_interval
|
||||
and len(self._accumulated) > 0)
|
||||
or len(self._accumulated) >= self.cfg.buffer_threshold
|
||||
)
|
||||
|
||||
if should_edit and self._accumulated:
|
||||
# Split overflow: if accumulated text exceeds the platform
|
||||
# limit, finalize the current message and start a new one.
|
||||
while (
|
||||
len(self._accumulated) > _safe_limit
|
||||
and self._message_id is not None
|
||||
):
|
||||
split_at = self._accumulated.rfind("\n", 0, _safe_limit)
|
||||
if split_at < _safe_limit // 2:
|
||||
split_at = _safe_limit
|
||||
chunk = self._accumulated[:split_at]
|
||||
await self._send_or_edit(chunk)
|
||||
self._accumulated = self._accumulated[split_at:].lstrip("\n")
|
||||
self._message_id = None
|
||||
self._last_sent_text = ""
|
||||
|
||||
display_text = self._accumulated
|
||||
if not got_done:
|
||||
display_text += self.cfg.cursor
|
||||
|
||||
await self._send_or_edit(display_text)
|
||||
self._last_edit_time = time.monotonic()
|
||||
|
||||
if got_done:
|
||||
# Final edit without cursor
|
||||
if self._accumulated and self._message_id:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.05) # Small yield to not busy-loop
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Best-effort final edit on cancellation
|
||||
if self._accumulated and self._message_id:
|
||||
try:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error("Stream consumer error: %s", e)
|
||||
|
||||
async def _send_or_edit(self, text: str) -> None:
|
||||
"""Send or edit the streaming message."""
|
||||
try:
|
||||
if self._message_id is not None:
|
||||
if self._edit_supported:
|
||||
# Skip if text is identical to what we last sent
|
||||
if text == self._last_sent_text:
|
||||
return
|
||||
# Edit existing message
|
||||
result = await self.adapter.edit_message(
|
||||
chat_id=self.chat_id,
|
||||
message_id=self._message_id,
|
||||
content=text,
|
||||
)
|
||||
if result.success:
|
||||
self._already_sent = True
|
||||
self._last_sent_text = text
|
||||
else:
|
||||
# Edit not supported by this adapter — stop streaming,
|
||||
# let the normal send path handle the final response.
|
||||
# Without this guard, adapters like Signal/Email would
|
||||
# flood the chat with a new message every edit_interval.
|
||||
logger.debug("Edit failed, disabling streaming for this adapter")
|
||||
self._edit_supported = False
|
||||
else:
|
||||
# Editing not supported — skip intermediate updates.
|
||||
# The final response will be sent by the normal path.
|
||||
pass
|
||||
else:
|
||||
# First message — send new
|
||||
result = await self.adapter.send(
|
||||
chat_id=self.chat_id,
|
||||
content=text,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
if result.success and result.message_id:
|
||||
self._message_id = result.message_id
|
||||
self._already_sent = True
|
||||
self._last_sent_text = text
|
||||
else:
|
||||
# Initial send failed — disable streaming for this session
|
||||
self._edit_supported = False
|
||||
except Exception as e:
|
||||
logger.error("Stream send/edit error: %s", e)
|
||||
@@ -11,5 +11,5 @@ Provides subcommands for:
|
||||
- hermes cron - Manage cron jobs
|
||||
"""
|
||||
|
||||
__version__ = "0.2.0"
|
||||
__release_date__ = "2026.3.12"
|
||||
__version__ = "0.4.0"
|
||||
__release_date__ = "2026.3.18"
|
||||
|
||||
+246
-16
@@ -19,6 +19,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import shlex
|
||||
import stat
|
||||
import base64
|
||||
import hashlib
|
||||
@@ -66,6 +67,8 @@ DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
|
||||
DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s
|
||||
DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
DEFAULT_GITHUB_MODELS_BASE_URL = "https://api.githubcopilot.com"
|
||||
DEFAULT_COPILOT_ACP_BASE_URL = "acp://copilot"
|
||||
CODEX_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
CODEX_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||
@@ -108,6 +111,20 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
auth_type="oauth_external",
|
||||
inference_base_url=DEFAULT_CODEX_BASE_URL,
|
||||
),
|
||||
"copilot": ProviderConfig(
|
||||
id="copilot",
|
||||
name="GitHub Copilot",
|
||||
auth_type="api_key",
|
||||
inference_base_url=DEFAULT_GITHUB_MODELS_BASE_URL,
|
||||
api_key_env_vars=("COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN"),
|
||||
),
|
||||
"copilot-acp": ProviderConfig(
|
||||
id="copilot-acp",
|
||||
name="GitHub Copilot ACP",
|
||||
auth_type="external_process",
|
||||
inference_base_url=DEFAULT_COPILOT_ACP_BASE_URL,
|
||||
base_url_env_var="COPILOT_ACP_BASE_URL",
|
||||
),
|
||||
"zai": ProviderConfig(
|
||||
id="zai",
|
||||
name="Z.AI / GLM",
|
||||
@@ -128,7 +145,7 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
id="minimax",
|
||||
name="MiniMax",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.minimax.io/v1",
|
||||
inference_base_url="https://api.minimax.io/anthropic",
|
||||
api_key_env_vars=("MINIMAX_API_KEY",),
|
||||
base_url_env_var="MINIMAX_BASE_URL",
|
||||
),
|
||||
@@ -139,14 +156,62 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
inference_base_url="https://api.anthropic.com",
|
||||
api_key_env_vars=("ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"),
|
||||
),
|
||||
"alibaba": ProviderConfig(
|
||||
id="alibaba",
|
||||
name="Alibaba Cloud (DashScope)",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://dashscope-intl.aliyuncs.com/apps/anthropic",
|
||||
api_key_env_vars=("DASHSCOPE_API_KEY",),
|
||||
base_url_env_var="DASHSCOPE_BASE_URL",
|
||||
),
|
||||
"minimax-cn": ProviderConfig(
|
||||
id="minimax-cn",
|
||||
name="MiniMax (China)",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.minimaxi.com/v1",
|
||||
inference_base_url="https://api.minimaxi.com/anthropic",
|
||||
api_key_env_vars=("MINIMAX_CN_API_KEY",),
|
||||
base_url_env_var="MINIMAX_CN_BASE_URL",
|
||||
),
|
||||
"deepseek": ProviderConfig(
|
||||
id="deepseek",
|
||||
name="DeepSeek",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.deepseek.com/v1",
|
||||
api_key_env_vars=("DEEPSEEK_API_KEY",),
|
||||
base_url_env_var="DEEPSEEK_BASE_URL",
|
||||
),
|
||||
"ai-gateway": ProviderConfig(
|
||||
id="ai-gateway",
|
||||
name="AI Gateway",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://ai-gateway.vercel.sh/v1",
|
||||
api_key_env_vars=("AI_GATEWAY_API_KEY",),
|
||||
base_url_env_var="AI_GATEWAY_BASE_URL",
|
||||
),
|
||||
"opencode-zen": ProviderConfig(
|
||||
id="opencode-zen",
|
||||
name="OpenCode Zen",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://opencode.ai/zen/v1",
|
||||
api_key_env_vars=("OPENCODE_ZEN_API_KEY",),
|
||||
base_url_env_var="OPENCODE_ZEN_BASE_URL",
|
||||
),
|
||||
"opencode-go": ProviderConfig(
|
||||
id="opencode-go",
|
||||
name="OpenCode Go",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://opencode.ai/zen/go/v1",
|
||||
api_key_env_vars=("OPENCODE_GO_API_KEY",),
|
||||
base_url_env_var="OPENCODE_GO_BASE_URL",
|
||||
),
|
||||
"kilocode": ProviderConfig(
|
||||
id="kilocode",
|
||||
name="Kilo Code",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.kilo.ai/api/gateway",
|
||||
api_key_env_vars=("KILOCODE_API_KEY",),
|
||||
base_url_env_var="KILOCODE_BASE_URL",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -174,6 +239,97 @@ def _resolve_kimi_base_url(api_key: str, default_url: str, env_override: str) ->
|
||||
return default_url
|
||||
|
||||
|
||||
def _gh_cli_candidates() -> list[str]:
|
||||
"""Return candidate ``gh`` binary paths, including common Homebrew installs."""
|
||||
candidates: list[str] = []
|
||||
|
||||
resolved = shutil.which("gh")
|
||||
if resolved:
|
||||
candidates.append(resolved)
|
||||
|
||||
for candidate in (
|
||||
"/opt/homebrew/bin/gh",
|
||||
"/usr/local/bin/gh",
|
||||
str(Path.home() / ".local" / "bin" / "gh"),
|
||||
):
|
||||
if candidate in candidates:
|
||||
continue
|
||||
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
|
||||
candidates.append(candidate)
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
def _try_gh_cli_token() -> Optional[str]:
|
||||
"""Return a token from ``gh auth token`` when the GitHub CLI is available."""
|
||||
for gh_path in _gh_cli_candidates():
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[gh_path, "auth", "token"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired) as exc:
|
||||
logger.debug("gh CLI token lookup failed (%s): %s", gh_path, exc)
|
||||
continue
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
return result.stdout.strip()
|
||||
return None
|
||||
|
||||
|
||||
_PLACEHOLDER_SECRET_VALUES = {
|
||||
"*",
|
||||
"**",
|
||||
"***",
|
||||
"changeme",
|
||||
"your_api_key",
|
||||
"your-api-key",
|
||||
"placeholder",
|
||||
"example",
|
||||
"dummy",
|
||||
"null",
|
||||
"none",
|
||||
}
|
||||
|
||||
|
||||
def has_usable_secret(value: Any, *, min_length: int = 4) -> bool:
|
||||
"""Return True when a configured secret looks usable, not empty/placeholder."""
|
||||
if not isinstance(value, str):
|
||||
return False
|
||||
cleaned = value.strip()
|
||||
if len(cleaned) < min_length:
|
||||
return False
|
||||
if cleaned.lower() in _PLACEHOLDER_SECRET_VALUES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _resolve_api_key_provider_secret(
|
||||
provider_id: str, pconfig: ProviderConfig
|
||||
) -> tuple[str, str]:
|
||||
"""Resolve an API-key provider's token and indicate where it came from."""
|
||||
if provider_id == "copilot":
|
||||
# Use the dedicated copilot auth module for proper token validation
|
||||
try:
|
||||
from hermes_cli.copilot_auth import resolve_copilot_token
|
||||
token, source = resolve_copilot_token()
|
||||
if token:
|
||||
return token, source
|
||||
except ValueError as exc:
|
||||
logger.warning("Copilot token validation failed: %s", exc)
|
||||
except Exception:
|
||||
pass
|
||||
return "", ""
|
||||
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
val = os.getenv(env_var, "").strip()
|
||||
if has_usable_secret(val):
|
||||
return val, env_var
|
||||
|
||||
return "", ""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Z.AI Endpoint Detection
|
||||
# =============================================================================
|
||||
@@ -524,6 +680,13 @@ def resolve_provider(
|
||||
"kimi": "kimi-coding", "moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
|
||||
"claude": "anthropic", "claude-code": "anthropic",
|
||||
"github": "copilot", "github-copilot": "copilot",
|
||||
"github-models": "copilot", "github-model": "copilot",
|
||||
"github-copilot-acp": "copilot-acp", "copilot-acp-agent": "copilot-acp",
|
||||
"aigateway": "ai-gateway", "vercel": "ai-gateway", "vercel-ai-gateway": "ai-gateway",
|
||||
"opencode": "opencode-zen", "zen": "opencode-zen",
|
||||
"go": "opencode-go", "opencode-go-sub": "opencode-go",
|
||||
"kilo": "kilocode", "kilo-code": "kilocode", "kilo-gateway": "kilocode",
|
||||
}
|
||||
normalized = _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
@@ -552,15 +715,20 @@ def resolve_provider(
|
||||
except Exception as e:
|
||||
logger.debug("Could not detect active auth provider: %s", e)
|
||||
|
||||
if os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"):
|
||||
if has_usable_secret(os.getenv("OPENAI_API_KEY")) or has_usable_secret(os.getenv("OPENROUTER_API_KEY")):
|
||||
return "openrouter"
|
||||
|
||||
# Auto-detect API-key providers by checking their env vars
|
||||
for pid, pconfig in PROVIDER_REGISTRY.items():
|
||||
if pconfig.auth_type != "api_key":
|
||||
continue
|
||||
# GitHub tokens are commonly present for repo/tool access but should not
|
||||
# hijack inference auto-selection unless the user explicitly chooses
|
||||
# Copilot/GitHub Models as the provider.
|
||||
if pid == "copilot":
|
||||
continue
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
if os.getenv(env_var, "").strip():
|
||||
if has_usable_secret(os.getenv(env_var, "")):
|
||||
return pid
|
||||
|
||||
return "openrouter"
|
||||
@@ -1427,12 +1595,7 @@ def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
|
||||
|
||||
api_key = ""
|
||||
key_source = ""
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
val = os.getenv(env_var, "").strip()
|
||||
if val:
|
||||
api_key = val
|
||||
key_source = env_var
|
||||
break
|
||||
api_key, key_source = _resolve_api_key_provider_secret(provider_id, pconfig)
|
||||
|
||||
env_url = ""
|
||||
if pconfig.base_url_env_var:
|
||||
@@ -1455,6 +1618,36 @@ def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def get_external_process_provider_status(provider_id: str) -> Dict[str, Any]:
|
||||
"""Status snapshot for providers that run a local subprocess."""
|
||||
pconfig = PROVIDER_REGISTRY.get(provider_id)
|
||||
if not pconfig or pconfig.auth_type != "external_process":
|
||||
return {"configured": False}
|
||||
|
||||
command = (
|
||||
os.getenv("HERMES_COPILOT_ACP_COMMAND", "").strip()
|
||||
or os.getenv("COPILOT_CLI_PATH", "").strip()
|
||||
or "copilot"
|
||||
)
|
||||
raw_args = os.getenv("HERMES_COPILOT_ACP_ARGS", "").strip()
|
||||
args = shlex.split(raw_args) if raw_args else ["--acp", "--stdio"]
|
||||
base_url = os.getenv(pconfig.base_url_env_var, "").strip() if pconfig.base_url_env_var else ""
|
||||
if not base_url:
|
||||
base_url = pconfig.inference_base_url
|
||||
|
||||
resolved_command = shutil.which(command) if command else None
|
||||
return {
|
||||
"configured": bool(resolved_command or base_url.startswith("acp+tcp://")),
|
||||
"provider": provider_id,
|
||||
"name": pconfig.name,
|
||||
"command": command,
|
||||
"args": args,
|
||||
"resolved_command": resolved_command,
|
||||
"base_url": base_url,
|
||||
"logged_in": bool(resolved_command or base_url.startswith("acp+tcp://")),
|
||||
}
|
||||
|
||||
|
||||
def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Generic auth status dispatcher."""
|
||||
target = provider_id or get_active_provider()
|
||||
@@ -1462,6 +1655,8 @@ def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
return get_nous_auth_status()
|
||||
if target == "openai-codex":
|
||||
return get_codex_auth_status()
|
||||
if target == "copilot-acp":
|
||||
return get_external_process_provider_status(target)
|
||||
# API-key providers
|
||||
pconfig = PROVIDER_REGISTRY.get(target)
|
||||
if pconfig and pconfig.auth_type == "api_key":
|
||||
@@ -1484,12 +1679,7 @@ def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
|
||||
|
||||
api_key = ""
|
||||
key_source = ""
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
val = os.getenv(env_var, "").strip()
|
||||
if val:
|
||||
api_key = val
|
||||
key_source = env_var
|
||||
break
|
||||
api_key, key_source = _resolve_api_key_provider_secret(provider_id, pconfig)
|
||||
|
||||
env_url = ""
|
||||
if pconfig.base_url_env_var:
|
||||
@@ -1510,6 +1700,46 @@ def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def resolve_external_process_provider_credentials(provider_id: str) -> Dict[str, Any]:
|
||||
"""Resolve runtime details for local subprocess-backed providers."""
|
||||
pconfig = PROVIDER_REGISTRY.get(provider_id)
|
||||
if not pconfig or pconfig.auth_type != "external_process":
|
||||
raise AuthError(
|
||||
f"Provider '{provider_id}' is not an external-process provider.",
|
||||
provider=provider_id,
|
||||
code="invalid_provider",
|
||||
)
|
||||
|
||||
base_url = os.getenv(pconfig.base_url_env_var, "").strip() if pconfig.base_url_env_var else ""
|
||||
if not base_url:
|
||||
base_url = pconfig.inference_base_url
|
||||
|
||||
command = (
|
||||
os.getenv("HERMES_COPILOT_ACP_COMMAND", "").strip()
|
||||
or os.getenv("COPILOT_CLI_PATH", "").strip()
|
||||
or "copilot"
|
||||
)
|
||||
raw_args = os.getenv("HERMES_COPILOT_ACP_ARGS", "").strip()
|
||||
args = shlex.split(raw_args) if raw_args else ["--acp", "--stdio"]
|
||||
resolved_command = shutil.which(command) if command else None
|
||||
if not resolved_command and not base_url.startswith("acp+tcp://"):
|
||||
raise AuthError(
|
||||
f"Could not find the Copilot CLI command '{command}'. "
|
||||
"Install GitHub Copilot CLI or set HERMES_COPILOT_ACP_COMMAND/COPILOT_CLI_PATH.",
|
||||
provider=provider_id,
|
||||
code="missing_copilot_cli",
|
||||
)
|
||||
|
||||
return {
|
||||
"provider": provider_id,
|
||||
"api_key": "copilot-acp",
|
||||
"base_url": base_url.rstrip("/"),
|
||||
"command": resolved_command or command,
|
||||
"args": args,
|
||||
"source": "process",
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# External credential detection
|
||||
# =============================================================================
|
||||
|
||||
+35
-27
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
|
||||
# ANSI building blocks for conversation display
|
||||
# =========================================================================
|
||||
|
||||
_GOLD = "\033[1;33m"
|
||||
_GOLD = "\033[1;38;2;255;215;0m" # True-color #FFD700 bold
|
||||
_BOLD = "\033[1m"
|
||||
_DIM = "\033[2m"
|
||||
_RST = "\033[0m"
|
||||
@@ -102,27 +102,22 @@ COMPACT_BANNER = """
|
||||
# =========================================================================
|
||||
|
||||
def get_available_skills() -> Dict[str, List[str]]:
|
||||
"""Scan ~/.hermes/skills/ and return skills grouped by category."""
|
||||
import os
|
||||
"""Return skills grouped by category, filtered by platform and disabled state.
|
||||
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
skills_dir = hermes_home / "skills"
|
||||
skills_by_category = {}
|
||||
|
||||
if not skills_dir.exists():
|
||||
return skills_by_category
|
||||
|
||||
for skill_file in skills_dir.rglob("SKILL.md"):
|
||||
rel_path = skill_file.relative_to(skills_dir)
|
||||
parts = rel_path.parts
|
||||
if len(parts) >= 2:
|
||||
category = parts[0]
|
||||
skill_name = parts[-2]
|
||||
else:
|
||||
category = "general"
|
||||
skill_name = skill_file.parent.name
|
||||
skills_by_category.setdefault(category, []).append(skill_name)
|
||||
Delegates to ``_find_all_skills()`` from ``tools/skills_tool`` which already
|
||||
handles platform gating (``platforms:`` frontmatter) and respects the
|
||||
user's ``skills.disabled`` config list.
|
||||
"""
|
||||
try:
|
||||
from tools.skills_tool import _find_all_skills
|
||||
all_skills = _find_all_skills() # already filtered
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
skills_by_category: Dict[str, List[str]] = {}
|
||||
for skill in all_skills:
|
||||
category = skill.get("category") or "general"
|
||||
skills_by_category.setdefault(category, []).append(skill["name"])
|
||||
return skills_by_category
|
||||
|
||||
|
||||
@@ -233,6 +228,17 @@ def _format_context_length(tokens: int) -> str:
|
||||
return str(tokens)
|
||||
|
||||
|
||||
def _display_toolset_name(toolset_name: str) -> str:
|
||||
"""Normalize internal/legacy toolset identifiers for banner display."""
|
||||
if not toolset_name:
|
||||
return "unknown"
|
||||
return (
|
||||
toolset_name[:-6]
|
||||
if toolset_name.endswith("_tools")
|
||||
else toolset_name
|
||||
)
|
||||
|
||||
|
||||
def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
tools: List[dict] = None,
|
||||
enabled_toolsets: List[str] = None,
|
||||
@@ -283,6 +289,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
_hero = HERMES_CADUCEUS
|
||||
left_lines = ["", _hero, ""]
|
||||
model_short = model.split("/")[-1] if "/" in model else model
|
||||
if model_short.endswith(".gguf"):
|
||||
model_short = model_short[:-5]
|
||||
if len(model_short) > 28:
|
||||
model_short = model_short[:25] + "..."
|
||||
ctx_str = f" [dim {dim}]·[/] [dim {dim}]{_format_context_length(context_length)} context[/]" if context_length else ""
|
||||
@@ -297,12 +305,12 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
|
||||
for tool in tools:
|
||||
tool_name = tool["function"]["name"]
|
||||
toolset = get_toolset_for_tool(tool_name) or "other"
|
||||
toolset = _display_toolset_name(get_toolset_for_tool(tool_name) or "other")
|
||||
toolsets_dict.setdefault(toolset, []).append(tool_name)
|
||||
|
||||
for item in unavailable_toolsets:
|
||||
toolset_id = item.get("id", item.get("name", "unknown"))
|
||||
display_name = f"{toolset_id}_tools" if not toolset_id.endswith("_tools") else toolset_id
|
||||
display_name = _display_toolset_name(toolset_id)
|
||||
if display_name not in toolsets_dict:
|
||||
toolsets_dict[display_name] = []
|
||||
for tool_name in item.get("tools", []):
|
||||
@@ -342,10 +350,10 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
colored_names.append(f"[{text}]{name}[/]")
|
||||
tools_str = ", ".join(colored_names)
|
||||
|
||||
right_lines.append(f"[dim #B8860B]{toolset}:[/] {tools_str}")
|
||||
right_lines.append(f"[dim {dim}]{toolset}:[/] {tools_str}")
|
||||
|
||||
if remaining_toolsets > 0:
|
||||
right_lines.append(f"[dim #B8860B](and {remaining_toolsets} more toolsets...)[/]")
|
||||
right_lines.append(f"[dim {dim}](and {remaining_toolsets} more toolsets...)[/]")
|
||||
|
||||
# MCP Servers section (only if configured)
|
||||
try:
|
||||
@@ -356,12 +364,12 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
|
||||
if mcp_status:
|
||||
right_lines.append("")
|
||||
right_lines.append("[bold #FFBF00]MCP Servers[/]")
|
||||
right_lines.append(f"[bold {accent}]MCP Servers[/]")
|
||||
for srv in mcp_status:
|
||||
if srv["connected"]:
|
||||
right_lines.append(
|
||||
f"[dim #B8860B]{srv['name']}[/] [#FFF8DC]({srv['transport']})[/] "
|
||||
f"[dim #B8860B]—[/] [#FFF8DC]{srv['tools']} tool(s)[/]"
|
||||
f"[dim {dim}]{srv['name']}[/] [{text}]({srv['transport']})[/] "
|
||||
f"[dim {dim}]—[/] [{text}]{srv['tools']} tool(s)[/]"
|
||||
)
|
||||
else:
|
||||
right_lines.append(
|
||||
|
||||
@@ -294,3 +294,18 @@ def _print_migration_report(report: dict, dry_run: bool):
|
||||
elif migrated:
|
||||
print()
|
||||
print_success("Migration complete!")
|
||||
# Warn if API keys were skipped (migrate_secrets not enabled)
|
||||
skipped_keys = [
|
||||
i for i in report.get("items", [])
|
||||
if i.get("kind") == "provider-keys" and i.get("status") == "skipped"
|
||||
]
|
||||
if skipped_keys:
|
||||
print()
|
||||
print(color(" ⚠ API keys were NOT migrated (secrets migration is disabled by default).", Colors.YELLOW))
|
||||
print(color(" Your OPENROUTER_API_KEY and other provider keys must be added manually.", Colors.YELLOW))
|
||||
print()
|
||||
print_info("To migrate API keys, re-run with:")
|
||||
print_info(" hermes claw migrate --migrate-secrets")
|
||||
print()
|
||||
print_info("Or add your key manually:")
|
||||
print_info(" hermes config set OPENROUTER_API_KEY sk-or-v1-...")
|
||||
|
||||
+703
-52
@@ -1,77 +1,359 @@
|
||||
"""Slash command definitions and autocomplete for the Hermes CLI.
|
||||
|
||||
Contains the shared built-in ``COMMANDS`` dict and ``SlashCommandCompleter``.
|
||||
The completer can optionally include dynamic skill slash commands supplied by the
|
||||
interactive CLI.
|
||||
Central registry for all slash commands. Every consumer -- CLI help, gateway
|
||||
dispatch, Telegram BotCommands, Slack subcommand mapping, autocomplete --
|
||||
derives its data from ``COMMAND_REGISTRY``.
|
||||
|
||||
To add a command: add a ``CommandDef`` entry to ``COMMAND_REGISTRY``.
|
||||
To add an alias: set ``aliases=("short",)`` on the existing ``CommandDef``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from prompt_toolkit.auto_suggest import AutoSuggest, Suggestion
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
|
||||
|
||||
# Commands organized by category for better help display
|
||||
COMMANDS_BY_CATEGORY = {
|
||||
"Session": {
|
||||
"/new": "Start a new session (fresh session ID + history)",
|
||||
"/reset": "Start a new session (alias for /new)",
|
||||
"/clear": "Clear screen and start a new session",
|
||||
"/history": "Show conversation history",
|
||||
"/save": "Save the current conversation",
|
||||
"/retry": "Retry the last message (resend to agent)",
|
||||
"/undo": "Remove the last user/assistant exchange",
|
||||
"/title": "Set a title for the current session (usage: /title My Session Name)",
|
||||
"/compress": "Manually compress conversation context (flush memories + summarize)",
|
||||
"/rollback": "List or restore filesystem checkpoints (usage: /rollback [number])",
|
||||
"/background": "Run a prompt in the background (usage: /background <prompt>)",
|
||||
},
|
||||
"Configuration": {
|
||||
"/config": "Show current configuration",
|
||||
"/model": "Show or change the current model",
|
||||
"/provider": "Show available providers and current provider",
|
||||
"/prompt": "View/set custom system prompt",
|
||||
"/personality": "Set a predefined personality",
|
||||
"/verbose": "Cycle tool progress display: off → new → all → verbose",
|
||||
"/reasoning": "Manage reasoning effort and display (usage: /reasoning [level|show|hide])",
|
||||
"/skin": "Show or change the display skin/theme",
|
||||
"/voice": "Toggle voice mode (Ctrl+B to record). Usage: /voice [on|off|tts|status]",
|
||||
},
|
||||
"Tools & Skills": {
|
||||
"/tools": "List available tools",
|
||||
"/toolsets": "List available toolsets",
|
||||
"/skills": "Search, install, inspect, or manage skills from online registries",
|
||||
"/cron": "Manage scheduled tasks (list, add/create, edit, pause, resume, run, remove)",
|
||||
"/reload-mcp": "Reload MCP servers from config.yaml",
|
||||
},
|
||||
"Info": {
|
||||
"/help": "Show this help message",
|
||||
"/usage": "Show token usage for the current session",
|
||||
"/insights": "Show usage insights and analytics (last 30 days)",
|
||||
"/platforms": "Show gateway/messaging platform status",
|
||||
"/paste": "Check clipboard for an image and attach it",
|
||||
},
|
||||
"Exit": {
|
||||
"/quit": "Exit the CLI (also: /exit, /q)",
|
||||
},
|
||||
}
|
||||
# ---------------------------------------------------------------------------
|
||||
# CommandDef dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Flat dict for backwards compatibility and autocomplete
|
||||
COMMANDS = {}
|
||||
for category_commands in COMMANDS_BY_CATEGORY.values():
|
||||
COMMANDS.update(category_commands)
|
||||
@dataclass(frozen=True)
|
||||
class CommandDef:
|
||||
"""Definition of a single slash command."""
|
||||
|
||||
name: str # canonical name without slash: "background"
|
||||
description: str # human-readable description
|
||||
category: str # "Session", "Configuration", etc.
|
||||
aliases: tuple[str, ...] = () # alternative names: ("bg",)
|
||||
args_hint: str = "" # argument placeholder: "<prompt>", "[name]"
|
||||
subcommands: tuple[str, ...] = () # tab-completable subcommands
|
||||
cli_only: bool = False # only available in CLI
|
||||
gateway_only: bool = False # only available in gateway/messaging
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Central registry -- single source of truth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
COMMAND_REGISTRY: list[CommandDef] = [
|
||||
# Session
|
||||
CommandDef("new", "Start a new session (fresh session ID + history)", "Session",
|
||||
aliases=("reset",)),
|
||||
CommandDef("clear", "Clear screen and start a new session", "Session",
|
||||
cli_only=True),
|
||||
CommandDef("history", "Show conversation history", "Session",
|
||||
cli_only=True),
|
||||
CommandDef("save", "Save the current conversation", "Session",
|
||||
cli_only=True),
|
||||
CommandDef("retry", "Retry the last message (resend to agent)", "Session"),
|
||||
CommandDef("undo", "Remove the last user/assistant exchange", "Session"),
|
||||
CommandDef("title", "Set a title for the current session", "Session",
|
||||
args_hint="[name]"),
|
||||
CommandDef("compress", "Manually compress conversation context", "Session"),
|
||||
CommandDef("rollback", "List or restore filesystem checkpoints", "Session",
|
||||
args_hint="[number]"),
|
||||
CommandDef("stop", "Kill all running background processes", "Session"),
|
||||
CommandDef("approve", "Approve a pending dangerous command", "Session",
|
||||
gateway_only=True, args_hint="[session|always]"),
|
||||
CommandDef("deny", "Deny a pending dangerous command", "Session",
|
||||
gateway_only=True),
|
||||
CommandDef("background", "Run a prompt in the background", "Session",
|
||||
aliases=("bg",), args_hint="<prompt>"),
|
||||
CommandDef("queue", "Queue a prompt for the next turn (doesn't interrupt)", "Session",
|
||||
aliases=("q",), args_hint="<prompt>"),
|
||||
CommandDef("status", "Show session info", "Session",
|
||||
gateway_only=True),
|
||||
CommandDef("sethome", "Set this chat as the home channel", "Session",
|
||||
gateway_only=True, aliases=("set-home",)),
|
||||
CommandDef("resume", "Resume a previously-named session", "Session",
|
||||
args_hint="[name]"),
|
||||
|
||||
# Configuration
|
||||
CommandDef("config", "Show current configuration", "Configuration",
|
||||
cli_only=True),
|
||||
CommandDef("model", "Show or change the current model", "Configuration",
|
||||
args_hint="[name]"),
|
||||
CommandDef("provider", "Show available providers and current provider",
|
||||
"Configuration"),
|
||||
CommandDef("prompt", "View/set custom system prompt", "Configuration",
|
||||
cli_only=True, args_hint="[text]", subcommands=("clear",)),
|
||||
CommandDef("personality", "Set a predefined personality", "Configuration",
|
||||
args_hint="[name]"),
|
||||
CommandDef("statusbar", "Toggle the context/model status bar", "Configuration",
|
||||
cli_only=True, aliases=("sb",)),
|
||||
CommandDef("verbose", "Cycle tool progress display: off -> new -> all -> verbose",
|
||||
"Configuration", cli_only=True),
|
||||
CommandDef("reasoning", "Manage reasoning effort and display", "Configuration",
|
||||
args_hint="[level|show|hide]",
|
||||
subcommands=("none", "low", "minimal", "medium", "high", "xhigh", "show", "hide", "on", "off")),
|
||||
CommandDef("skin", "Show or change the display skin/theme", "Configuration",
|
||||
cli_only=True, args_hint="[name]"),
|
||||
CommandDef("voice", "Toggle voice mode", "Configuration",
|
||||
args_hint="[on|off|tts|status]", subcommands=("on", "off", "tts", "status")),
|
||||
|
||||
# Tools & Skills
|
||||
CommandDef("tools", "Manage tools: /tools [list|disable|enable] [name...]", "Tools & Skills",
|
||||
args_hint="[list|disable|enable] [name...]", cli_only=True),
|
||||
CommandDef("toolsets", "List available toolsets", "Tools & Skills",
|
||||
cli_only=True),
|
||||
CommandDef("skills", "Search, install, inspect, or manage skills",
|
||||
"Tools & Skills", cli_only=True,
|
||||
subcommands=("search", "browse", "inspect", "install")),
|
||||
CommandDef("cron", "Manage scheduled tasks", "Tools & Skills",
|
||||
cli_only=True, args_hint="[subcommand]",
|
||||
subcommands=("list", "add", "create", "edit", "pause", "resume", "run", "remove")),
|
||||
CommandDef("reload-mcp", "Reload MCP servers from config", "Tools & Skills",
|
||||
aliases=("reload_mcp",)),
|
||||
CommandDef("browser", "Connect browser tools to your live Chrome via CDP", "Tools & Skills",
|
||||
cli_only=True, args_hint="[connect|disconnect|status]",
|
||||
subcommands=("connect", "disconnect", "status")),
|
||||
CommandDef("plugins", "List installed plugins and their status",
|
||||
"Tools & Skills", cli_only=True),
|
||||
|
||||
# Info
|
||||
CommandDef("help", "Show available commands", "Info"),
|
||||
CommandDef("usage", "Show token usage for the current session", "Info"),
|
||||
CommandDef("insights", "Show usage insights and analytics", "Info",
|
||||
args_hint="[days]"),
|
||||
CommandDef("platforms", "Show gateway/messaging platform status", "Info",
|
||||
cli_only=True, aliases=("gateway",)),
|
||||
CommandDef("paste", "Check clipboard for an image and attach it", "Info",
|
||||
cli_only=True),
|
||||
CommandDef("update", "Update Hermes Agent to the latest version", "Info",
|
||||
gateway_only=True),
|
||||
|
||||
# Exit
|
||||
CommandDef("quit", "Exit the CLI", "Exit",
|
||||
cli_only=True, aliases=("exit", "q")),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Derived lookups -- rebuilt once at import time, refreshed by rebuild_lookups()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_command_lookup() -> dict[str, CommandDef]:
|
||||
"""Map every name and alias to its CommandDef."""
|
||||
lookup: dict[str, CommandDef] = {}
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
lookup[cmd.name] = cmd
|
||||
for alias in cmd.aliases:
|
||||
lookup[alias] = cmd
|
||||
return lookup
|
||||
|
||||
|
||||
_COMMAND_LOOKUP: dict[str, CommandDef] = _build_command_lookup()
|
||||
|
||||
|
||||
def resolve_command(name: str) -> CommandDef | None:
|
||||
"""Resolve a command name or alias to its CommandDef.
|
||||
|
||||
Accepts names with or without the leading slash.
|
||||
"""
|
||||
return _COMMAND_LOOKUP.get(name.lower().lstrip("/"))
|
||||
|
||||
|
||||
def register_plugin_command(cmd: CommandDef) -> None:
|
||||
"""Append a plugin-defined command to the registry and refresh lookups."""
|
||||
COMMAND_REGISTRY.append(cmd)
|
||||
rebuild_lookups()
|
||||
|
||||
|
||||
def rebuild_lookups() -> None:
|
||||
"""Rebuild all derived lookup dicts from the current COMMAND_REGISTRY.
|
||||
|
||||
Called after plugin commands are registered so they appear in help,
|
||||
autocomplete, gateway dispatch, Telegram menu, and Slack mapping.
|
||||
"""
|
||||
global GATEWAY_KNOWN_COMMANDS
|
||||
|
||||
_COMMAND_LOOKUP.clear()
|
||||
_COMMAND_LOOKUP.update(_build_command_lookup())
|
||||
|
||||
COMMANDS.clear()
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if not cmd.gateway_only:
|
||||
COMMANDS[f"/{cmd.name}"] = _build_description(cmd)
|
||||
for alias in cmd.aliases:
|
||||
COMMANDS[f"/{alias}"] = f"{cmd.description} (alias for /{cmd.name})"
|
||||
|
||||
COMMANDS_BY_CATEGORY.clear()
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if not cmd.gateway_only:
|
||||
cat = COMMANDS_BY_CATEGORY.setdefault(cmd.category, {})
|
||||
cat[f"/{cmd.name}"] = COMMANDS[f"/{cmd.name}"]
|
||||
for alias in cmd.aliases:
|
||||
cat[f"/{alias}"] = COMMANDS[f"/{alias}"]
|
||||
|
||||
SUBCOMMANDS.clear()
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if cmd.subcommands:
|
||||
SUBCOMMANDS[f"/{cmd.name}"] = list(cmd.subcommands)
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
key = f"/{cmd.name}"
|
||||
if key in SUBCOMMANDS or not cmd.args_hint:
|
||||
continue
|
||||
m = _PIPE_SUBS_RE.search(cmd.args_hint)
|
||||
if m:
|
||||
SUBCOMMANDS[key] = m.group(0).split("|")
|
||||
|
||||
GATEWAY_KNOWN_COMMANDS = frozenset(
|
||||
name
|
||||
for cmd in COMMAND_REGISTRY
|
||||
if not cmd.cli_only
|
||||
for name in (cmd.name, *cmd.aliases)
|
||||
)
|
||||
|
||||
|
||||
def _build_description(cmd: CommandDef) -> str:
|
||||
"""Build a CLI-facing description string including usage hint."""
|
||||
if cmd.args_hint:
|
||||
return f"{cmd.description} (usage: /{cmd.name} {cmd.args_hint})"
|
||||
return cmd.description
|
||||
|
||||
|
||||
# Backwards-compatible flat dict: "/command" -> description
|
||||
COMMANDS: dict[str, str] = {}
|
||||
for _cmd in COMMAND_REGISTRY:
|
||||
if not _cmd.gateway_only:
|
||||
COMMANDS[f"/{_cmd.name}"] = _build_description(_cmd)
|
||||
for _alias in _cmd.aliases:
|
||||
COMMANDS[f"/{_alias}"] = f"{_cmd.description} (alias for /{_cmd.name})"
|
||||
|
||||
# Backwards-compatible categorized dict
|
||||
COMMANDS_BY_CATEGORY: dict[str, dict[str, str]] = {}
|
||||
for _cmd in COMMAND_REGISTRY:
|
||||
if not _cmd.gateway_only:
|
||||
_cat = COMMANDS_BY_CATEGORY.setdefault(_cmd.category, {})
|
||||
_cat[f"/{_cmd.name}"] = COMMANDS[f"/{_cmd.name}"]
|
||||
for _alias in _cmd.aliases:
|
||||
_cat[f"/{_alias}"] = COMMANDS[f"/{_alias}"]
|
||||
|
||||
|
||||
# Subcommands lookup: "/cmd" -> ["sub1", "sub2", ...]
|
||||
SUBCOMMANDS: dict[str, list[str]] = {}
|
||||
for _cmd in COMMAND_REGISTRY:
|
||||
if _cmd.subcommands:
|
||||
SUBCOMMANDS[f"/{_cmd.name}"] = list(_cmd.subcommands)
|
||||
|
||||
# Also extract subcommands hinted in args_hint via pipe-separated patterns
|
||||
# e.g. args_hint="[on|off|tts|status]" for commands that don't have explicit subcommands.
|
||||
# NOTE: If a command already has explicit subcommands, this fallback is skipped.
|
||||
# Use the `subcommands` field on CommandDef for intentional tab-completable args.
|
||||
_PIPE_SUBS_RE = re.compile(r"[a-z]+(?:\|[a-z]+)+")
|
||||
for _cmd in COMMAND_REGISTRY:
|
||||
key = f"/{_cmd.name}"
|
||||
if key in SUBCOMMANDS or not _cmd.args_hint:
|
||||
continue
|
||||
m = _PIPE_SUBS_RE.search(_cmd.args_hint)
|
||||
if m:
|
||||
SUBCOMMANDS[key] = m.group(0).split("|")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gateway helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Set of all command names + aliases recognized by the gateway
|
||||
GATEWAY_KNOWN_COMMANDS: frozenset[str] = frozenset(
|
||||
name
|
||||
for cmd in COMMAND_REGISTRY
|
||||
if not cmd.cli_only
|
||||
for name in (cmd.name, *cmd.aliases)
|
||||
)
|
||||
|
||||
|
||||
def gateway_help_lines() -> list[str]:
|
||||
"""Generate gateway help text lines from the registry."""
|
||||
lines: list[str] = []
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if cmd.cli_only:
|
||||
continue
|
||||
args = f" {cmd.args_hint}" if cmd.args_hint else ""
|
||||
alias_parts: list[str] = []
|
||||
for a in cmd.aliases:
|
||||
# Skip internal aliases like reload_mcp (underscore variant)
|
||||
if a.replace("-", "_") == cmd.name.replace("-", "_") and a != cmd.name:
|
||||
continue
|
||||
alias_parts.append(f"`/{a}`")
|
||||
alias_note = f" (alias: {', '.join(alias_parts)})" if alias_parts else ""
|
||||
lines.append(f"`/{cmd.name}{args}` -- {cmd.description}{alias_note}")
|
||||
return lines
|
||||
|
||||
|
||||
def telegram_bot_commands() -> list[tuple[str, str]]:
|
||||
"""Return (command_name, description) pairs for Telegram setMyCommands.
|
||||
|
||||
Telegram command names cannot contain hyphens, so they are replaced with
|
||||
underscores. Aliases are skipped -- Telegram shows one menu entry per
|
||||
canonical command.
|
||||
"""
|
||||
result: list[tuple[str, str]] = []
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if cmd.cli_only:
|
||||
continue
|
||||
tg_name = cmd.name.replace("-", "_")
|
||||
result.append((tg_name, cmd.description))
|
||||
return result
|
||||
|
||||
|
||||
def slack_subcommand_map() -> dict[str, str]:
|
||||
"""Return subcommand -> /command mapping for Slack /hermes handler.
|
||||
|
||||
Maps both canonical names and aliases so /hermes bg do stuff works
|
||||
the same as /hermes background do stuff.
|
||||
"""
|
||||
mapping: dict[str, str] = {}
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if cmd.cli_only:
|
||||
continue
|
||||
mapping[cmd.name] = f"/{cmd.name}"
|
||||
for alias in cmd.aliases:
|
||||
mapping[alias] = f"/{alias}"
|
||||
return mapping
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Autocomplete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SlashCommandCompleter(Completer):
|
||||
"""Autocomplete for built-in slash commands and optional skill commands."""
|
||||
"""Autocomplete for built-in slash commands, subcommands, and skill commands."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
skill_commands_provider: Callable[[], Mapping[str, dict[str, Any]]] | None = None,
|
||||
model_completer_provider: Callable[[], dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
self._skill_commands_provider = skill_commands_provider
|
||||
# model_completer_provider returns {"current_provider": str,
|
||||
# "providers": {id: label, ...}, "models_for": callable(provider) -> list[str]}
|
||||
self._model_completer_provider = model_completer_provider
|
||||
self._model_info_cache: dict[str, Any] | None = None
|
||||
self._model_info_cache_time: float = 0
|
||||
|
||||
def _get_model_info(self) -> dict[str, Any]:
|
||||
"""Get cached model/provider info for /model autocomplete."""
|
||||
import time
|
||||
now = time.monotonic()
|
||||
if self._model_info_cache is not None and now - self._model_info_cache_time < 60:
|
||||
return self._model_info_cache
|
||||
if self._model_completer_provider is None:
|
||||
return {}
|
||||
try:
|
||||
self._model_info_cache = self._model_completer_provider() or {}
|
||||
self._model_info_cache_time = now
|
||||
except Exception:
|
||||
self._model_info_cache = self._model_info_cache or {}
|
||||
return self._model_info_cache
|
||||
|
||||
def _iter_skill_commands(self) -> Mapping[str, dict[str, Any]]:
|
||||
if self._skill_commands_provider is None:
|
||||
@@ -92,9 +374,279 @@ class SlashCommandCompleter(Completer):
|
||||
"""
|
||||
return f"{cmd_name} " if cmd_name == word else cmd_name
|
||||
|
||||
@staticmethod
|
||||
def _extract_path_word(text: str) -> str | None:
|
||||
"""Extract the current word if it looks like a file path.
|
||||
|
||||
Returns the path-like token under the cursor, or None if the
|
||||
current word doesn't look like a path. A word is path-like when
|
||||
it starts with ``./``, ``../``, ``~/``, ``/``, or contains a
|
||||
``/`` separator (e.g. ``src/main.py``).
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
# Walk backwards to find the start of the current "word".
|
||||
# Words are delimited by spaces, but paths can contain almost anything.
|
||||
i = len(text) - 1
|
||||
while i >= 0 and text[i] != " ":
|
||||
i -= 1
|
||||
word = text[i + 1:]
|
||||
if not word:
|
||||
return None
|
||||
# Only trigger path completion for path-like tokens
|
||||
if word.startswith(("./", "../", "~/", "/")) or "/" in word:
|
||||
return word
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _path_completions(word: str, limit: int = 30):
|
||||
"""Yield Completion objects for file paths matching *word*."""
|
||||
expanded = os.path.expanduser(word)
|
||||
# Split into directory part and prefix to match inside it
|
||||
if expanded.endswith("/"):
|
||||
search_dir = expanded
|
||||
prefix = ""
|
||||
else:
|
||||
search_dir = os.path.dirname(expanded) or "."
|
||||
prefix = os.path.basename(expanded)
|
||||
|
||||
try:
|
||||
entries = os.listdir(search_dir)
|
||||
except OSError:
|
||||
return
|
||||
|
||||
count = 0
|
||||
prefix_lower = prefix.lower()
|
||||
for entry in sorted(entries):
|
||||
if prefix and not entry.lower().startswith(prefix_lower):
|
||||
continue
|
||||
if count >= limit:
|
||||
break
|
||||
|
||||
full_path = os.path.join(search_dir, entry)
|
||||
is_dir = os.path.isdir(full_path)
|
||||
|
||||
# Build the completion text (what replaces the typed word)
|
||||
if word.startswith("~"):
|
||||
display_path = "~/" + os.path.relpath(full_path, os.path.expanduser("~"))
|
||||
elif os.path.isabs(word):
|
||||
display_path = full_path
|
||||
else:
|
||||
# Keep relative
|
||||
display_path = os.path.relpath(full_path)
|
||||
|
||||
if is_dir:
|
||||
display_path += "/"
|
||||
|
||||
suffix = "/" if is_dir else ""
|
||||
meta = "dir" if is_dir else _file_size_label(full_path)
|
||||
|
||||
yield Completion(
|
||||
display_path,
|
||||
start_position=-len(word),
|
||||
display=entry + suffix,
|
||||
display_meta=meta,
|
||||
)
|
||||
count += 1
|
||||
|
||||
@staticmethod
|
||||
def _extract_context_word(text: str) -> str | None:
|
||||
"""Extract a bare ``@`` token for context reference completions."""
|
||||
if not text:
|
||||
return None
|
||||
# Walk backwards to find the start of the current word
|
||||
i = len(text) - 1
|
||||
while i >= 0 and text[i] != " ":
|
||||
i -= 1
|
||||
word = text[i + 1:]
|
||||
if not word.startswith("@"):
|
||||
return None
|
||||
return word
|
||||
|
||||
@staticmethod
|
||||
def _context_completions(word: str, limit: int = 30):
|
||||
"""Yield Claude Code-style @ context completions.
|
||||
|
||||
Bare ``@`` or ``@partial`` shows static references and matching
|
||||
files/folders. ``@file:path`` and ``@folder:path`` are handled
|
||||
by the existing path completion path.
|
||||
"""
|
||||
lowered = word.lower()
|
||||
|
||||
# Static context references
|
||||
_STATIC_REFS = (
|
||||
("@diff", "Git working tree diff"),
|
||||
("@staged", "Git staged diff"),
|
||||
("@file:", "Attach a file"),
|
||||
("@folder:", "Attach a folder"),
|
||||
("@git:", "Git log with diffs (e.g. @git:5)"),
|
||||
("@url:", "Fetch web content"),
|
||||
)
|
||||
for candidate, meta in _STATIC_REFS:
|
||||
if candidate.lower().startswith(lowered) and candidate.lower() != lowered:
|
||||
yield Completion(
|
||||
candidate,
|
||||
start_position=-len(word),
|
||||
display=candidate,
|
||||
display_meta=meta,
|
||||
)
|
||||
|
||||
# If the user typed @file: or @folder:, delegate to path completions
|
||||
for prefix in ("@file:", "@folder:"):
|
||||
if word.startswith(prefix):
|
||||
path_part = word[len(prefix):] or "."
|
||||
expanded = os.path.expanduser(path_part)
|
||||
if expanded.endswith("/"):
|
||||
search_dir, match_prefix = expanded, ""
|
||||
else:
|
||||
search_dir = os.path.dirname(expanded) or "."
|
||||
match_prefix = os.path.basename(expanded)
|
||||
|
||||
try:
|
||||
entries = os.listdir(search_dir)
|
||||
except OSError:
|
||||
return
|
||||
|
||||
count = 0
|
||||
prefix_lower = match_prefix.lower()
|
||||
for entry in sorted(entries):
|
||||
if match_prefix and not entry.lower().startswith(prefix_lower):
|
||||
continue
|
||||
if count >= limit:
|
||||
break
|
||||
full_path = os.path.join(search_dir, entry)
|
||||
is_dir = os.path.isdir(full_path)
|
||||
display_path = os.path.relpath(full_path)
|
||||
suffix = "/" if is_dir else ""
|
||||
kind = "folder" if is_dir else "file"
|
||||
meta = "dir" if is_dir else _file_size_label(full_path)
|
||||
completion = f"@{kind}:{display_path}{suffix}"
|
||||
yield Completion(
|
||||
completion,
|
||||
start_position=-len(word),
|
||||
display=entry + suffix,
|
||||
display_meta=meta,
|
||||
)
|
||||
count += 1
|
||||
return
|
||||
|
||||
# Bare @ or @partial — show matching files/folders from cwd
|
||||
query = word[1:] # strip the @
|
||||
if not query:
|
||||
search_dir, match_prefix = ".", ""
|
||||
else:
|
||||
expanded = os.path.expanduser(query)
|
||||
if expanded.endswith("/"):
|
||||
search_dir, match_prefix = expanded, ""
|
||||
else:
|
||||
search_dir = os.path.dirname(expanded) or "."
|
||||
match_prefix = os.path.basename(expanded)
|
||||
|
||||
try:
|
||||
entries = os.listdir(search_dir)
|
||||
except OSError:
|
||||
return
|
||||
|
||||
count = 0
|
||||
prefix_lower = match_prefix.lower()
|
||||
for entry in sorted(entries):
|
||||
if match_prefix and not entry.lower().startswith(prefix_lower):
|
||||
continue
|
||||
if entry.startswith("."):
|
||||
continue # skip hidden files in bare @ mode
|
||||
if count >= limit:
|
||||
break
|
||||
full_path = os.path.join(search_dir, entry)
|
||||
is_dir = os.path.isdir(full_path)
|
||||
display_path = os.path.relpath(full_path)
|
||||
suffix = "/" if is_dir else ""
|
||||
kind = "folder" if is_dir else "file"
|
||||
meta = "dir" if is_dir else _file_size_label(full_path)
|
||||
completion = f"@{kind}:{display_path}{suffix}"
|
||||
yield Completion(
|
||||
completion,
|
||||
start_position=-len(word),
|
||||
display=entry + suffix,
|
||||
display_meta=meta,
|
||||
)
|
||||
count += 1
|
||||
|
||||
def get_completions(self, document, complete_event):
|
||||
text = document.text_before_cursor
|
||||
if not text.startswith("/"):
|
||||
# Try @ context completion (Claude Code-style)
|
||||
ctx_word = self._extract_context_word(text)
|
||||
if ctx_word is not None:
|
||||
yield from self._context_completions(ctx_word)
|
||||
return
|
||||
# Try file path completion for non-slash input
|
||||
path_word = self._extract_path_word(text)
|
||||
if path_word is not None:
|
||||
yield from self._path_completions(path_word)
|
||||
return
|
||||
|
||||
# Check if we're completing a subcommand (base command already typed)
|
||||
parts = text.split(maxsplit=1)
|
||||
base_cmd = parts[0].lower()
|
||||
if len(parts) > 1 or (len(parts) == 1 and text.endswith(" ")):
|
||||
sub_text = parts[1] if len(parts) > 1 else ""
|
||||
sub_lower = sub_text.lower()
|
||||
|
||||
# /model gets two-stage completion:
|
||||
# Stage 1: provider names (with : suffix)
|
||||
# Stage 2: after "provider:", list that provider's models
|
||||
if base_cmd == "/model" and " " not in sub_text:
|
||||
info = self._get_model_info()
|
||||
if info:
|
||||
current_prov = info.get("current_provider", "")
|
||||
providers = info.get("providers", {})
|
||||
models_for = info.get("models_for")
|
||||
|
||||
if ":" in sub_text:
|
||||
# Stage 2: "anthropic:cl" → models for anthropic
|
||||
prov_part, model_part = sub_text.split(":", 1)
|
||||
model_lower = model_part.lower()
|
||||
if models_for:
|
||||
try:
|
||||
prov_models = models_for(prov_part)
|
||||
except Exception:
|
||||
prov_models = []
|
||||
for mid in prov_models:
|
||||
if mid.lower().startswith(model_lower) and mid.lower() != model_lower:
|
||||
full = f"{prov_part}:{mid}"
|
||||
yield Completion(
|
||||
full,
|
||||
start_position=-len(sub_text),
|
||||
display=mid,
|
||||
)
|
||||
else:
|
||||
# Stage 1: providers sorted: non-current first, current last
|
||||
for pid, plabel in sorted(
|
||||
providers.items(),
|
||||
key=lambda kv: (kv[0] == current_prov, kv[0]),
|
||||
):
|
||||
display_name = f"{pid}:"
|
||||
if display_name.lower().startswith(sub_lower):
|
||||
meta = f"({plabel})" if plabel != pid else ""
|
||||
if pid == current_prov:
|
||||
meta = f"(current — {plabel})" if plabel != pid else "(current)"
|
||||
yield Completion(
|
||||
display_name,
|
||||
start_position=-len(sub_text),
|
||||
display=display_name,
|
||||
display_meta=meta,
|
||||
)
|
||||
return
|
||||
|
||||
# Static subcommand completions
|
||||
if " " not in sub_text and base_cmd in SUBCOMMANDS:
|
||||
for sub in SUBCOMMANDS[base_cmd]:
|
||||
if sub.startswith(sub_lower) and sub != sub_lower:
|
||||
yield Completion(
|
||||
sub,
|
||||
start_position=-len(sub_text),
|
||||
display=sub,
|
||||
)
|
||||
return
|
||||
|
||||
word = text[1:]
|
||||
@@ -120,3 +672,102 @@ class SlashCommandCompleter(Completer):
|
||||
display=cmd,
|
||||
display_meta=f"⚡ {short_desc}",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inline auto-suggest (ghost text) for slash commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SlashCommandAutoSuggest(AutoSuggest):
|
||||
"""Inline ghost-text suggestions for slash commands and their subcommands.
|
||||
|
||||
Shows the rest of a command or subcommand in dim text as you type.
|
||||
Falls back to history-based suggestions for non-slash input.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
history_suggest: AutoSuggest | None = None,
|
||||
completer: SlashCommandCompleter | None = None,
|
||||
) -> None:
|
||||
self._history = history_suggest
|
||||
self._completer = completer # Reuse its model cache
|
||||
|
||||
def get_suggestion(self, buffer, document):
|
||||
text = document.text_before_cursor
|
||||
|
||||
# Only suggest for slash commands
|
||||
if not text.startswith("/"):
|
||||
# Fall back to history for regular text
|
||||
if self._history:
|
||||
return self._history.get_suggestion(buffer, document)
|
||||
return None
|
||||
|
||||
parts = text.split(maxsplit=1)
|
||||
base_cmd = parts[0].lower()
|
||||
|
||||
if len(parts) == 1 and not text.endswith(" "):
|
||||
# Still typing the command name: /upd → suggest "ate"
|
||||
word = text[1:].lower()
|
||||
for cmd in COMMANDS:
|
||||
cmd_name = cmd[1:] # strip leading /
|
||||
if cmd_name.startswith(word) and cmd_name != word:
|
||||
return Suggestion(cmd_name[len(word):])
|
||||
return None
|
||||
|
||||
# Command is complete — suggest subcommands or model names
|
||||
sub_text = parts[1] if len(parts) > 1 else ""
|
||||
sub_lower = sub_text.lower()
|
||||
|
||||
# /model gets two-stage ghost text
|
||||
if base_cmd == "/model" and " " not in sub_text and self._completer:
|
||||
info = self._completer._get_model_info()
|
||||
if info:
|
||||
providers = info.get("providers", {})
|
||||
models_for = info.get("models_for")
|
||||
current_prov = info.get("current_provider", "")
|
||||
|
||||
if ":" in sub_text:
|
||||
# Stage 2: after provider:, suggest model
|
||||
prov_part, model_part = sub_text.split(":", 1)
|
||||
model_lower = model_part.lower()
|
||||
if models_for:
|
||||
try:
|
||||
for mid in models_for(prov_part):
|
||||
if mid.lower().startswith(model_lower) and mid.lower() != model_lower:
|
||||
return Suggestion(mid[len(model_part):])
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# Stage 1: suggest provider name with :
|
||||
for pid in sorted(providers, key=lambda p: (p == current_prov, p)):
|
||||
candidate = f"{pid}:"
|
||||
if candidate.lower().startswith(sub_lower) and candidate.lower() != sub_lower:
|
||||
return Suggestion(candidate[len(sub_text):])
|
||||
|
||||
# Static subcommands
|
||||
if base_cmd in SUBCOMMANDS and SUBCOMMANDS[base_cmd]:
|
||||
if " " not in sub_text:
|
||||
for sub in SUBCOMMANDS[base_cmd]:
|
||||
if sub.startswith(sub_lower) and sub != sub_lower:
|
||||
return Suggestion(sub[len(sub_text):])
|
||||
|
||||
# Fall back to history
|
||||
if self._history:
|
||||
return self._history.get_suggestion(buffer, document)
|
||||
return None
|
||||
|
||||
|
||||
def _file_size_label(path: str) -> str:
|
||||
"""Return a compact human-readable file size, or '' on error."""
|
||||
try:
|
||||
size = os.path.getsize(path)
|
||||
except OSError:
|
||||
return ""
|
||||
if size < 1024:
|
||||
return f"{size}B"
|
||||
if size < 1024 * 1024:
|
||||
return f"{size / 1024:.0f}K"
|
||||
if size < 1024 * 1024 * 1024:
|
||||
return f"{size / (1024 * 1024):.1f}M"
|
||||
return f"{size / (1024 * 1024 * 1024):.1f}G"
|
||||
|
||||
+426
-12
@@ -16,7 +16,6 @@ import os
|
||||
import platform
|
||||
import re
|
||||
import stat
|
||||
import sys
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -25,6 +24,21 @@ from typing import Dict, Any, Optional, List, Tuple
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
_ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
# Env var names written to .env that aren't in OPTIONAL_ENV_VARS
|
||||
# (managed by setup/provider flows directly).
|
||||
_EXTRA_ENV_KEYS = frozenset({
|
||||
"OPENAI_API_KEY", "OPENAI_BASE_URL",
|
||||
"ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN",
|
||||
"AUXILIARY_VISION_MODEL",
|
||||
"DISCORD_HOME_CHANNEL", "TELEGRAM_HOME_CHANNEL",
|
||||
"SIGNAL_ACCOUNT", "SIGNAL_HTTP_URL",
|
||||
"SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET",
|
||||
"TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT",
|
||||
"WHATSAPP_MODE", "WHATSAPP_ENABLED",
|
||||
"MATTERMOST_HOME_CHANNEL", "MATTERMOST_REPLY_MODE",
|
||||
"MATRIX_PASSWORD", "MATRIX_ENCRYPTION", "MATRIX_HOME_ROOM",
|
||||
})
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -106,6 +120,7 @@ DEFAULT_CONFIG = {
|
||||
"cwd": ".", # Use current directory
|
||||
"timeout": 180,
|
||||
"docker_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"docker_forward_env": [],
|
||||
"singularity_image": "docker://nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"modal_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
@@ -118,6 +133,14 @@ DEFAULT_CONFIG = {
|
||||
# Each entry is "host_path:container_path" (standard Docker -v syntax).
|
||||
# Example: ["/home/user/projects:/workspace/projects", "/data:/data"]
|
||||
"docker_volumes": [],
|
||||
# Explicit opt-in: mount the host cwd into /workspace for Docker sessions.
|
||||
# Default off because passing host directories into a sandbox weakens isolation.
|
||||
"docker_mount_cwd_to_workspace": False,
|
||||
# Persistent shell — keep a long-lived bash shell across execute() calls
|
||||
# so cwd/env vars/shell variables survive between commands.
|
||||
# Enabled by default for non-local backends (SSH); local is always opt-in
|
||||
# via TERMINAL_LOCAL_PERSISTENT env var.
|
||||
"persistent_shell": True,
|
||||
},
|
||||
|
||||
"browser": {
|
||||
@@ -129,15 +152,22 @@ DEFAULT_CONFIG = {
|
||||
# When enabled, the agent takes a snapshot of the working directory once per
|
||||
# conversation turn (on first write_file/patch call). Use /rollback to restore.
|
||||
"checkpoints": {
|
||||
"enabled": False,
|
||||
"enabled": True,
|
||||
"max_snapshots": 50, # Max checkpoints to keep per directory
|
||||
},
|
||||
|
||||
"compression": {
|
||||
"enabled": True,
|
||||
"threshold": 0.50,
|
||||
"summary_model": "google/gemini-3-flash-preview",
|
||||
"summary_model": "", # empty = use main configured model
|
||||
"summary_provider": "auto",
|
||||
"summary_base_url": None,
|
||||
},
|
||||
"smart_model_routing": {
|
||||
"enabled": False,
|
||||
"max_simple_chars": 160,
|
||||
"max_simple_words": 28,
|
||||
"cheap_model": {},
|
||||
},
|
||||
|
||||
# Auxiliary model config — provider:model for each side task.
|
||||
@@ -152,6 +182,7 @@ DEFAULT_CONFIG = {
|
||||
"model": "", # e.g. "google/gemini-2.5-flash", "gpt-4o"
|
||||
"base_url": "", # direct OpenAI-compatible endpoint (takes precedence over provider)
|
||||
"api_key": "", # API key for base_url (falls back to OPENAI_API_KEY)
|
||||
"timeout": 30, # seconds — increase for slow local vision models
|
||||
},
|
||||
"web_extract": {
|
||||
"provider": "auto",
|
||||
@@ -177,6 +208,12 @@ DEFAULT_CONFIG = {
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
},
|
||||
"approval": {
|
||||
"provider": "auto",
|
||||
"model": "", # fast/cheap model recommended (e.g. gemini-flash, haiku)
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
},
|
||||
"mcp": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
@@ -197,12 +234,19 @@ DEFAULT_CONFIG = {
|
||||
"resume_display": "full",
|
||||
"bell_on_complete": False,
|
||||
"show_reasoning": False,
|
||||
"streaming": False,
|
||||
"show_cost": False, # Show $ cost in the status bar (off by default)
|
||||
"skin": "default",
|
||||
},
|
||||
|
||||
# Privacy settings
|
||||
"privacy": {
|
||||
"redact_pii": False, # When True, hash user IDs and strip phone numbers from LLM context
|
||||
},
|
||||
|
||||
# Text-to-speech configuration
|
||||
"tts": {
|
||||
"provider": "edge", # "edge" (free) | "elevenlabs" (premium) | "openai"
|
||||
"provider": "edge", # "edge" (free) | "elevenlabs" (premium) | "openai" | "neutts" (local)
|
||||
"edge": {
|
||||
"voice": "en-US-AriaNeural",
|
||||
# Popular: AriaNeural, JennyNeural, AndrewNeural, BrianNeural, SoniaNeural
|
||||
@@ -216,6 +260,12 @@ DEFAULT_CONFIG = {
|
||||
"voice": "alloy",
|
||||
# Voices: alloy, echo, fable, onyx, nova, shimmer
|
||||
},
|
||||
"neutts": {
|
||||
"ref_audio": "", # Path to reference voice audio (empty = bundled default)
|
||||
"ref_text": "", # Path to reference voice transcript (empty = bundled default)
|
||||
"model": "neuphonic/neutts-air-q4-gguf", # HuggingFace model repo
|
||||
"device": "cpu", # cpu, cuda, or mps
|
||||
},
|
||||
},
|
||||
|
||||
"stt": {
|
||||
@@ -283,6 +333,22 @@ DEFAULT_CONFIG = {
|
||||
"auto_thread": True, # Auto-create threads on @mention in channels (like Slack)
|
||||
},
|
||||
|
||||
# WhatsApp platform settings (gateway mode)
|
||||
"whatsapp": {
|
||||
# Reply prefix prepended to every outgoing WhatsApp message.
|
||||
# Default (None) uses the built-in "⚕ *Hermes Agent*" header.
|
||||
# Set to "" (empty string) to disable the header entirely.
|
||||
# Supports \n for newlines, e.g. "🤖 *My Bot*\n──────\n"
|
||||
},
|
||||
|
||||
# Approval mode for dangerous commands:
|
||||
# manual — always prompt the user (default)
|
||||
# smart — use auxiliary LLM to auto-approve low-risk commands, prompt for high-risk
|
||||
# off — skip all approval prompts (equivalent to --yolo)
|
||||
"approvals": {
|
||||
"mode": "manual",
|
||||
},
|
||||
|
||||
# Permanently allowed dangerous command patterns (added via "always" approval)
|
||||
"command_allowlist": [],
|
||||
# User-defined quick commands that bypass the agent loop (type: exec only)
|
||||
@@ -299,10 +365,15 @@ DEFAULT_CONFIG = {
|
||||
"tirith_path": "tirith",
|
||||
"tirith_timeout": 5,
|
||||
"tirith_fail_open": True,
|
||||
"website_blocklist": {
|
||||
"enabled": False,
|
||||
"domains": [],
|
||||
"shared_files": [],
|
||||
},
|
||||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 8,
|
||||
"_config_version": 10,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
@@ -316,6 +387,7 @@ ENV_VARS_BY_VERSION: Dict[int, List[str]] = {
|
||||
4: ["VOICE_TOOLS_OPENAI_KEY", "ELEVENLABS_API_KEY"],
|
||||
5: ["WHATSAPP_ENABLED", "WHATSAPP_MODE", "WHATSAPP_ALLOWED_USERS",
|
||||
"SLACK_BOT_TOKEN", "SLACK_APP_TOKEN", "SLACK_ALLOWED_USERS"],
|
||||
10: ["TAVILY_API_KEY"],
|
||||
}
|
||||
|
||||
# Required environment variables with metadata for migration prompts.
|
||||
@@ -424,8 +496,77 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"DEEPSEEK_API_KEY": {
|
||||
"description": "DeepSeek API key for direct DeepSeek access",
|
||||
"prompt": "DeepSeek API Key",
|
||||
"url": "https://platform.deepseek.com/api_keys",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
},
|
||||
"DEEPSEEK_BASE_URL": {
|
||||
"description": "Custom DeepSeek API base URL (advanced)",
|
||||
"prompt": "DeepSeek Base URL",
|
||||
"url": "",
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
},
|
||||
"DASHSCOPE_API_KEY": {
|
||||
"description": "Alibaba Cloud DashScope API key for Qwen models",
|
||||
"prompt": "DashScope API Key",
|
||||
"url": "https://modelstudio.console.alibabacloud.com/",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
},
|
||||
"DASHSCOPE_BASE_URL": {
|
||||
"description": "Custom DashScope base URL (default: international endpoint)",
|
||||
"prompt": "DashScope Base URL",
|
||||
"url": "",
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"OPENCODE_ZEN_API_KEY": {
|
||||
"description": "OpenCode Zen API key (pay-as-you-go access to curated models)",
|
||||
"prompt": "OpenCode Zen API key",
|
||||
"url": "https://opencode.ai/auth",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"OPENCODE_ZEN_BASE_URL": {
|
||||
"description": "OpenCode Zen base URL override",
|
||||
"prompt": "OpenCode Zen base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"OPENCODE_GO_API_KEY": {
|
||||
"description": "OpenCode Go API key ($10/month subscription for open models)",
|
||||
"prompt": "OpenCode Go API key",
|
||||
"url": "https://opencode.ai/auth",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"OPENCODE_GO_BASE_URL": {
|
||||
"description": "OpenCode Go base URL override",
|
||||
"prompt": "OpenCode Go base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
|
||||
# ── Tool API keys ──
|
||||
"PARALLEL_API_KEY": {
|
||||
"description": "Parallel API key for AI-native web search and extract",
|
||||
"prompt": "Parallel API key",
|
||||
"url": "https://parallel.ai/",
|
||||
"tools": ["web_search", "web_extract"],
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"FIRECRAWL_API_KEY": {
|
||||
"description": "Firecrawl API key for web search and scraping",
|
||||
"prompt": "Firecrawl API key",
|
||||
@@ -442,6 +583,14 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "tool",
|
||||
"advanced": True,
|
||||
},
|
||||
"TAVILY_API_KEY": {
|
||||
"description": "Tavily API key for AI-native web search, extract, and crawl",
|
||||
"prompt": "Tavily API key",
|
||||
"url": "https://app.tavily.com/home",
|
||||
"tools": ["web_search", "web_extract", "web_crawl"],
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"BROWSERBASE_API_KEY": {
|
||||
"description": "Browserbase API key for cloud browser (optional — local browser works without this)",
|
||||
"prompt": "Browserbase API key",
|
||||
@@ -458,6 +607,14 @@ OPTIONAL_ENV_VARS = {
|
||||
"password": False,
|
||||
"category": "tool",
|
||||
},
|
||||
"BROWSER_USE_API_KEY": {
|
||||
"description": "Browser Use API key for cloud browser (optional — local browser works without this)",
|
||||
"prompt": "Browser Use API key",
|
||||
"url": "https://browser-use.com/",
|
||||
"tools": ["browser_navigate", "browser_click"],
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"FAL_KEY": {
|
||||
"description": "FAL API key for image generation",
|
||||
"prompt": "FAL API key",
|
||||
@@ -514,6 +671,11 @@ OPTIONAL_ENV_VARS = {
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"HONCHO_BASE_URL": {
|
||||
"description": "Base URL for self-hosted Honcho instances (no API key needed)",
|
||||
"prompt": "Honcho base URL (e.g. http://localhost:8000)",
|
||||
"category": "tool",
|
||||
},
|
||||
|
||||
# ── Messaging platforms ──
|
||||
"TELEGRAM_BOT_TOKEN": {
|
||||
@@ -562,6 +724,55 @@ OPTIONAL_ENV_VARS = {
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
},
|
||||
"MATTERMOST_URL": {
|
||||
"description": "Mattermost server URL (e.g. https://mm.example.com)",
|
||||
"prompt": "Mattermost server URL",
|
||||
"url": "https://mattermost.com/deploy/",
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"MATTERMOST_TOKEN": {
|
||||
"description": "Mattermost bot token or personal access token",
|
||||
"prompt": "Mattermost bot token",
|
||||
"url": None,
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
},
|
||||
"MATTERMOST_ALLOWED_USERS": {
|
||||
"description": "Comma-separated Mattermost user IDs allowed to use the bot",
|
||||
"prompt": "Allowed Mattermost user IDs (comma-separated)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"MATRIX_HOMESERVER": {
|
||||
"description": "Matrix homeserver URL (e.g. https://matrix.example.org)",
|
||||
"prompt": "Matrix homeserver URL",
|
||||
"url": "https://matrix.org/ecosystem/servers/",
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"MATRIX_ACCESS_TOKEN": {
|
||||
"description": "Matrix access token (preferred over password login)",
|
||||
"prompt": "Matrix access token",
|
||||
"url": None,
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
},
|
||||
"MATRIX_USER_ID": {
|
||||
"description": "Matrix user ID (e.g. @hermes:example.org)",
|
||||
"prompt": "Matrix user ID (@user:server)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"MATRIX_ALLOWED_USERS": {
|
||||
"description": "Comma-separated Matrix user IDs allowed to use the bot (@user:server format)",
|
||||
"prompt": "Allowed Matrix user IDs (comma-separated)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"GATEWAY_ALLOW_ALL_USERS": {
|
||||
"description": "Allow all users to interact with messaging bots (true/false). Default: false.",
|
||||
"prompt": "Allow all users (true/false)",
|
||||
@@ -570,6 +781,59 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "messaging",
|
||||
"advanced": True,
|
||||
},
|
||||
"API_SERVER_ENABLED": {
|
||||
"description": "Enable the OpenAI-compatible API server (true/false). Allows frontends like Open WebUI, LobeChat, etc. to connect.",
|
||||
"prompt": "Enable API server (true/false)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
"advanced": True,
|
||||
},
|
||||
"API_SERVER_KEY": {
|
||||
"description": "Bearer token for API server authentication. If empty, all requests are allowed (local use only).",
|
||||
"prompt": "API server auth key (optional)",
|
||||
"url": None,
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
"advanced": True,
|
||||
},
|
||||
"API_SERVER_PORT": {
|
||||
"description": "Port for the API server (default: 8642).",
|
||||
"prompt": "API server port",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
"advanced": True,
|
||||
},
|
||||
"API_SERVER_HOST": {
|
||||
"description": "Host/bind address for the API server (default: 127.0.0.1). Use 0.0.0.0 for network access — requires API_SERVER_KEY for security.",
|
||||
"prompt": "API server host",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
"advanced": True,
|
||||
},
|
||||
"WEBHOOK_ENABLED": {
|
||||
"description": "Enable the webhook platform adapter for receiving events from GitHub, GitLab, etc.",
|
||||
"prompt": "Enable webhooks (true/false)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"WEBHOOK_PORT": {
|
||||
"description": "Port for the webhook HTTP server (default: 8644).",
|
||||
"prompt": "Webhook port",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"WEBHOOK_SECRET": {
|
||||
"description": "Global HMAC secret for webhook signature validation (overridable per route in config.yaml).",
|
||||
"prompt": "Webhook secret",
|
||||
"url": None,
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
},
|
||||
|
||||
# ── Agent settings ──
|
||||
"MESSAGING_CWD": {
|
||||
@@ -716,7 +980,15 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
|
||||
Dict with migration results: {"env_added": [...], "config_added": [...], "warnings": [...]}
|
||||
"""
|
||||
results = {"env_added": [], "config_added": [], "warnings": []}
|
||||
|
||||
|
||||
# ── Always: sanitize .env (split concatenated keys) ──
|
||||
try:
|
||||
fixes = sanitize_env_file()
|
||||
if fixes and not quiet:
|
||||
print(f" ✓ Repaired .env file ({fixes} corrupted entries fixed)")
|
||||
except Exception:
|
||||
pass # best-effort; don't block migration on sanitize failure
|
||||
|
||||
# Check config version
|
||||
current_ver, latest_ver = check_config_version()
|
||||
|
||||
@@ -759,6 +1031,18 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
|
||||
tz_display = config["timezone"] or "(server-local)"
|
||||
print(f" ✓ Added timezone to config.yaml: {tz_display}")
|
||||
|
||||
# ── Version 8 → 9: clear ANTHROPIC_TOKEN from .env ──
|
||||
# The new Anthropic auth flow no longer uses this env var.
|
||||
if current_ver < 9:
|
||||
try:
|
||||
old_token = get_env_value("ANTHROPIC_TOKEN")
|
||||
if old_token:
|
||||
save_env_value("ANTHROPIC_TOKEN", "")
|
||||
if not quiet:
|
||||
print(" ✓ Cleared ANTHROPIC_TOKEN from .env (no longer used)")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if current_ver < latest_ver and not quiet:
|
||||
print(f"Config version: {current_ver} → {latest_ver}")
|
||||
|
||||
@@ -968,6 +1252,19 @@ _FALLBACK_COMMENT = """
|
||||
# fallback_model:
|
||||
# provider: openrouter
|
||||
# model: anthropic/claude-sonnet-4
|
||||
#
|
||||
# ── Smart Model Routing ────────────────────────────────────────────────
|
||||
# Optional cheap-vs-strong routing for simple turns.
|
||||
# Keeps the primary model for complex work, but can route short/simple
|
||||
# messages to a cheaper model across providers.
|
||||
#
|
||||
# smart_model_routing:
|
||||
# enabled: true
|
||||
# max_simple_chars: 160
|
||||
# max_simple_words: 28
|
||||
# cheap_model:
|
||||
# provider: openrouter
|
||||
# model: google/gemini-2.5-flash
|
||||
"""
|
||||
|
||||
|
||||
@@ -998,6 +1295,19 @@ _COMMENTED_SECTIONS = """
|
||||
# fallback_model:
|
||||
# provider: openrouter
|
||||
# model: anthropic/claude-sonnet-4
|
||||
#
|
||||
# ── Smart Model Routing ────────────────────────────────────────────────
|
||||
# Optional cheap-vs-strong routing for simple turns.
|
||||
# Keeps the primary model for complex work, but can route short/simple
|
||||
# messages to a cheaper model across providers.
|
||||
#
|
||||
# smart_model_routing:
|
||||
# enabled: true
|
||||
# max_simple_chars: 160
|
||||
# max_simple_words: 28
|
||||
# cheap_model:
|
||||
# provider: openrouter
|
||||
# model: google/gemini-2.5-flash
|
||||
"""
|
||||
|
||||
|
||||
@@ -1046,6 +1356,102 @@ def load_env() -> Dict[str, str]:
|
||||
return env_vars
|
||||
|
||||
|
||||
def _sanitize_env_lines(lines: list) -> list:
|
||||
"""Fix corrupted .env lines before writing.
|
||||
|
||||
Handles two known corruption patterns:
|
||||
1. Concatenated KEY=VALUE pairs on a single line (missing newline between
|
||||
entries, e.g. ``ANTHROPIC_API_KEY=sk-...OPENAI_BASE_URL=https://...``).
|
||||
2. Stale ``KEY=***`` placeholder entries left by incomplete setup runs.
|
||||
|
||||
Uses a known-keys set (OPTIONAL_ENV_VARS + _EXTRA_ENV_KEYS) so we only
|
||||
split on real Hermes env var names, avoiding false positives from values
|
||||
that happen to contain uppercase text with ``=``.
|
||||
"""
|
||||
# Build the known keys set lazily from OPTIONAL_ENV_VARS + extras.
|
||||
# Done inside the function so OPTIONAL_ENV_VARS is guaranteed to be defined.
|
||||
known_keys = set(OPTIONAL_ENV_VARS.keys()) | _EXTRA_ENV_KEYS
|
||||
|
||||
sanitized: list[str] = []
|
||||
for line in lines:
|
||||
raw = line.rstrip("\r\n")
|
||||
stripped = raw.strip()
|
||||
|
||||
# Preserve blank lines and comments
|
||||
if not stripped or stripped.startswith("#"):
|
||||
sanitized.append(raw + "\n")
|
||||
continue
|
||||
|
||||
# Detect concatenated KEY=VALUE pairs on one line.
|
||||
# Search for known KEY= patterns at any position in the line.
|
||||
split_positions = []
|
||||
for key_name in known_keys:
|
||||
needle = key_name + "="
|
||||
idx = stripped.find(needle)
|
||||
while idx >= 0:
|
||||
split_positions.append(idx)
|
||||
idx = stripped.find(needle, idx + len(needle))
|
||||
|
||||
if len(split_positions) > 1:
|
||||
split_positions.sort()
|
||||
# Deduplicate (shouldn't happen, but be safe)
|
||||
split_positions = sorted(set(split_positions))
|
||||
for i, pos in enumerate(split_positions):
|
||||
end = split_positions[i + 1] if i + 1 < len(split_positions) else len(stripped)
|
||||
part = stripped[pos:end].strip()
|
||||
if part:
|
||||
sanitized.append(part + "\n")
|
||||
else:
|
||||
sanitized.append(stripped + "\n")
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_env_file() -> int:
|
||||
"""Read, sanitize, and rewrite ~/.hermes/.env in place.
|
||||
|
||||
Returns the number of lines that were fixed (concatenation splits +
|
||||
placeholder removals). Returns 0 when no changes are needed.
|
||||
"""
|
||||
env_path = get_env_path()
|
||||
if not env_path.exists():
|
||||
return 0
|
||||
|
||||
read_kw = {"encoding": "utf-8", "errors": "replace"} if _IS_WINDOWS else {}
|
||||
write_kw = {"encoding": "utf-8"} if _IS_WINDOWS else {}
|
||||
|
||||
with open(env_path, **read_kw) as f:
|
||||
original_lines = f.readlines()
|
||||
|
||||
sanitized = _sanitize_env_lines(original_lines)
|
||||
|
||||
if sanitized == original_lines:
|
||||
return 0
|
||||
|
||||
# Count fixes: difference in line count (from splits) + removed lines
|
||||
fixes = abs(len(sanitized) - len(original_lines))
|
||||
if fixes == 0:
|
||||
# Lines changed content (e.g. *** removal) even if count is same
|
||||
fixes = sum(1 for a, b in zip(original_lines, sanitized) if a != b)
|
||||
fixes += abs(len(sanitized) - len(original_lines))
|
||||
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(env_path.parent), suffix=".tmp", prefix=".env_")
|
||||
try:
|
||||
with os.fdopen(fd, "w", **write_kw) as f:
|
||||
f.writelines(sanitized)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, env_path)
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
_secure_file(env_path)
|
||||
return fixes
|
||||
|
||||
|
||||
def save_env_value(key: str, value: str):
|
||||
"""Save or update a value in ~/.hermes/.env."""
|
||||
if not _ENV_VAR_NAME_RE.match(key):
|
||||
@@ -1063,6 +1469,8 @@ def save_env_value(key: str, value: str):
|
||||
if env_path.exists():
|
||||
with open(env_path, **read_kw) as f:
|
||||
lines = f.readlines()
|
||||
# Sanitize on every read: split concatenated keys, drop stale placeholders
|
||||
lines = _sanitize_env_lines(lines)
|
||||
|
||||
# Find and update or append
|
||||
found = False
|
||||
@@ -1181,8 +1589,11 @@ def show_config():
|
||||
keys = [
|
||||
("OPENROUTER_API_KEY", "OpenRouter"),
|
||||
("VOICE_TOOLS_OPENAI_KEY", "OpenAI (STT/TTS)"),
|
||||
("PARALLEL_API_KEY", "Parallel"),
|
||||
("FIRECRAWL_API_KEY", "Firecrawl"),
|
||||
("TAVILY_API_KEY", "Tavily"),
|
||||
("BROWSERBASE_API_KEY", "Browserbase"),
|
||||
("BROWSER_USE_API_KEY", "Browser Use"),
|
||||
("FAL_KEY", "FAL"),
|
||||
]
|
||||
|
||||
@@ -1197,7 +1608,6 @@ def show_config():
|
||||
print(color("◆ Model", Colors.CYAN, Colors.BOLD))
|
||||
print(f" Model: {config.get('model', 'not set')}")
|
||||
print(f" Max turns: {config.get('agent', {}).get('max_turns', DEFAULT_CONFIG['agent']['max_turns'])}")
|
||||
print(f" Toolsets: {', '.join(config.get('toolsets', ['all']))}")
|
||||
|
||||
# Display
|
||||
print()
|
||||
@@ -1216,11 +1626,11 @@ def show_config():
|
||||
print(f" Timeout: {terminal.get('timeout', 60)}s")
|
||||
|
||||
if terminal.get('backend') == 'docker':
|
||||
print(f" Docker image: {terminal.get('docker_image', 'python:3.11-slim')}")
|
||||
print(f" Docker image: {terminal.get('docker_image', 'nikolaik/python-nodejs:python3.11-nodejs20')}")
|
||||
elif terminal.get('backend') == 'singularity':
|
||||
print(f" Image: {terminal.get('singularity_image', 'docker://python:3.11')}")
|
||||
print(f" Image: {terminal.get('singularity_image', 'docker://nikolaik/python-nodejs:python3.11-nodejs20')}")
|
||||
elif terminal.get('backend') == 'modal':
|
||||
print(f" Modal image: {terminal.get('modal_image', 'python:3.11')}")
|
||||
print(f" Modal image: {terminal.get('modal_image', 'nikolaik/python-nodejs:python3.11-nodejs20')}")
|
||||
modal_token = get_env_value('MODAL_TOKEN_ID')
|
||||
print(f" Modal token: {'configured' if modal_token else '(not set)'}")
|
||||
elif terminal.get('backend') == 'daytona':
|
||||
@@ -1250,7 +1660,8 @@ def show_config():
|
||||
print(f" Enabled: {'yes' if enabled else 'no'}")
|
||||
if enabled:
|
||||
print(f" Threshold: {compression.get('threshold', 0.50) * 100:.0f}%")
|
||||
print(f" Model: {compression.get('summary_model', 'google/gemini-3-flash-preview')}")
|
||||
_sm = compression.get('summary_model', '') or '(main model)'
|
||||
print(f" Model: {_sm}")
|
||||
comp_provider = compression.get('summary_provider', 'auto')
|
||||
if comp_provider != 'auto':
|
||||
print(f" Provider: {comp_provider}")
|
||||
@@ -1329,7 +1740,8 @@ def set_config_value(key: str, value: str):
|
||||
# Check if it's an API key (goes to .env)
|
||||
api_keys = [
|
||||
'OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY',
|
||||
'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID',
|
||||
'PARALLEL_API_KEY', 'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'TAVILY_API_KEY',
|
||||
'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID', 'BROWSER_USE_API_KEY',
|
||||
'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
|
||||
'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
|
||||
'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN',
|
||||
@@ -1388,9 +1800,11 @@ def set_config_value(key: str, value: str):
|
||||
"terminal.singularity_image": "TERMINAL_SINGULARITY_IMAGE",
|
||||
"terminal.modal_image": "TERMINAL_MODAL_IMAGE",
|
||||
"terminal.daytona_image": "TERMINAL_DAYTONA_IMAGE",
|
||||
"terminal.docker_mount_cwd_to_workspace": "TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE",
|
||||
"terminal.cwd": "TERMINAL_CWD",
|
||||
"terminal.timeout": "TERMINAL_TIMEOUT",
|
||||
"terminal.sandbox_dir": "TERMINAL_SANDBOX_DIR",
|
||||
"terminal.persistent_shell": "TERMINAL_PERSISTENT_SHELL",
|
||||
}
|
||||
if key in _config_to_env_sync:
|
||||
save_env_value(_config_to_env_sync[key], str(value))
|
||||
|
||||
@@ -0,0 +1,295 @@
|
||||
"""GitHub Copilot authentication utilities.
|
||||
|
||||
Implements the OAuth device code flow used by the Copilot CLI and handles
|
||||
token validation/exchange for the Copilot API.
|
||||
|
||||
Token type support (per GitHub docs):
|
||||
gho_ OAuth token ✓ (default via copilot login)
|
||||
github_pat_ Fine-grained PAT ✓ (needs Copilot Requests permission)
|
||||
ghu_ GitHub App token ✓ (via environment variable)
|
||||
ghp_ Classic PAT ✗ NOT SUPPORTED
|
||||
|
||||
Credential search order (matching Copilot CLI behaviour):
|
||||
1. COPILOT_GITHUB_TOKEN env var
|
||||
2. GH_TOKEN env var
|
||||
3. GITHUB_TOKEN env var
|
||||
4. gh auth token CLI fallback
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OAuth device code flow constants (same client ID as opencode/Copilot CLI)
|
||||
COPILOT_OAUTH_CLIENT_ID = "Ov23li8tweQw6odWQebz"
|
||||
COPILOT_DEVICE_CODE_URL = "https://github.com/login/device/code"
|
||||
COPILOT_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||
|
||||
# Copilot API constants
|
||||
COPILOT_TOKEN_EXCHANGE_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||
COPILOT_API_BASE_URL = "https://api.githubcopilot.com"
|
||||
|
||||
# Token type prefixes
|
||||
_CLASSIC_PAT_PREFIX = "ghp_"
|
||||
_SUPPORTED_PREFIXES = ("gho_", "github_pat_", "ghu_")
|
||||
|
||||
# Env var search order (matches Copilot CLI)
|
||||
COPILOT_ENV_VARS = ("COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN")
|
||||
|
||||
# Polling constants
|
||||
_DEVICE_CODE_POLL_INTERVAL = 5 # seconds
|
||||
_DEVICE_CODE_POLL_SAFETY_MARGIN = 3 # seconds
|
||||
|
||||
|
||||
def is_classic_pat(token: str) -> bool:
|
||||
"""Check if a token is a classic PAT (ghp_*), which Copilot doesn't support."""
|
||||
return token.strip().startswith(_CLASSIC_PAT_PREFIX)
|
||||
|
||||
|
||||
def validate_copilot_token(token: str) -> tuple[bool, str]:
|
||||
"""Validate that a token is usable with the Copilot API.
|
||||
|
||||
Returns (valid, message).
|
||||
"""
|
||||
token = token.strip()
|
||||
if not token:
|
||||
return False, "Empty token"
|
||||
|
||||
if token.startswith(_CLASSIC_PAT_PREFIX):
|
||||
return False, (
|
||||
"Classic Personal Access Tokens (ghp_*) are not supported by the "
|
||||
"Copilot API. Use one of:\n"
|
||||
" → `copilot login` or `hermes model` to authenticate via OAuth\n"
|
||||
" → A fine-grained PAT (github_pat_*) with Copilot Requests permission\n"
|
||||
" → `gh auth login` with the default device code flow (produces gho_* tokens)"
|
||||
)
|
||||
|
||||
return True, "OK"
|
||||
|
||||
|
||||
def resolve_copilot_token() -> tuple[str, str]:
|
||||
"""Resolve a GitHub token suitable for Copilot API use.
|
||||
|
||||
Returns (token, source) where source describes where the token came from.
|
||||
Raises ValueError if only a classic PAT is available.
|
||||
"""
|
||||
# 1. Check env vars in priority order
|
||||
for env_var in COPILOT_ENV_VARS:
|
||||
val = os.getenv(env_var, "").strip()
|
||||
if val:
|
||||
valid, msg = validate_copilot_token(val)
|
||||
if not valid:
|
||||
logger.warning(
|
||||
"Token from %s is not supported: %s", env_var, msg
|
||||
)
|
||||
continue
|
||||
return val, env_var
|
||||
|
||||
# 2. Fall back to gh auth token
|
||||
token = _try_gh_cli_token()
|
||||
if token:
|
||||
valid, msg = validate_copilot_token(token)
|
||||
if not valid:
|
||||
raise ValueError(
|
||||
f"Token from `gh auth token` is a classic PAT (ghp_*). {msg}"
|
||||
)
|
||||
return token, "gh auth token"
|
||||
|
||||
return "", ""
|
||||
|
||||
|
||||
def _gh_cli_candidates() -> list[str]:
|
||||
"""Return candidate ``gh`` binary paths, including common Homebrew installs."""
|
||||
candidates: list[str] = []
|
||||
|
||||
resolved = shutil.which("gh")
|
||||
if resolved:
|
||||
candidates.append(resolved)
|
||||
|
||||
for candidate in (
|
||||
"/opt/homebrew/bin/gh",
|
||||
"/usr/local/bin/gh",
|
||||
str(Path.home() / ".local" / "bin" / "gh"),
|
||||
):
|
||||
if candidate in candidates:
|
||||
continue
|
||||
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
|
||||
candidates.append(candidate)
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
def _try_gh_cli_token() -> Optional[str]:
|
||||
"""Return a token from ``gh auth token`` when the GitHub CLI is available."""
|
||||
for gh_path in _gh_cli_candidates():
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[gh_path, "auth", "token"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired) as exc:
|
||||
logger.debug("gh CLI token lookup failed (%s): %s", gh_path, exc)
|
||||
continue
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
return result.stdout.strip()
|
||||
return None
|
||||
|
||||
|
||||
# ─── OAuth Device Code Flow ────────────────────────────────────────────────
|
||||
|
||||
def copilot_device_code_login(
|
||||
*,
|
||||
host: str = "github.com",
|
||||
timeout_seconds: float = 300,
|
||||
) -> Optional[str]:
|
||||
"""Run the GitHub OAuth device code flow for Copilot.
|
||||
|
||||
Prints instructions for the user, polls for completion, and returns
|
||||
the OAuth access token on success, or None on failure/cancellation.
|
||||
|
||||
This replicates the flow used by opencode and the Copilot CLI.
|
||||
"""
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
|
||||
domain = host.rstrip("/")
|
||||
device_code_url = f"https://{domain}/login/device/code"
|
||||
access_token_url = f"https://{domain}/login/oauth/access_token"
|
||||
|
||||
# Step 1: Request device code
|
||||
data = urllib.parse.urlencode({
|
||||
"client_id": COPILOT_OAUTH_CLIENT_ID,
|
||||
"scope": "read:user",
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(
|
||||
device_code_url,
|
||||
data=data,
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": "HermesAgent/1.0",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
||||
device_data = json.loads(resp.read().decode())
|
||||
except Exception as exc:
|
||||
logger.error("Failed to initiate device authorization: %s", exc)
|
||||
print(f" ✗ Failed to start device authorization: {exc}")
|
||||
return None
|
||||
|
||||
verification_uri = device_data.get("verification_uri", "https://github.com/login/device")
|
||||
user_code = device_data.get("user_code", "")
|
||||
device_code = device_data.get("device_code", "")
|
||||
interval = max(device_data.get("interval", _DEVICE_CODE_POLL_INTERVAL), 1)
|
||||
|
||||
if not device_code or not user_code:
|
||||
print(" ✗ GitHub did not return a device code.")
|
||||
return None
|
||||
|
||||
# Step 2: Show instructions
|
||||
print()
|
||||
print(f" Open this URL in your browser: {verification_uri}")
|
||||
print(f" Enter this code: {user_code}")
|
||||
print()
|
||||
print(" Waiting for authorization...", end="", flush=True)
|
||||
|
||||
# Step 3: Poll for completion
|
||||
deadline = time.time() + timeout_seconds
|
||||
|
||||
while time.time() < deadline:
|
||||
time.sleep(interval + _DEVICE_CODE_POLL_SAFETY_MARGIN)
|
||||
|
||||
poll_data = urllib.parse.urlencode({
|
||||
"client_id": COPILOT_OAUTH_CLIENT_ID,
|
||||
"device_code": device_code,
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
}).encode()
|
||||
|
||||
poll_req = urllib.request.Request(
|
||||
access_token_url,
|
||||
data=poll_data,
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": "HermesAgent/1.0",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(poll_req, timeout=10) as resp:
|
||||
result = json.loads(resp.read().decode())
|
||||
except Exception:
|
||||
print(".", end="", flush=True)
|
||||
continue
|
||||
|
||||
if result.get("access_token"):
|
||||
print(" ✓")
|
||||
return result["access_token"]
|
||||
|
||||
error = result.get("error", "")
|
||||
if error == "authorization_pending":
|
||||
print(".", end="", flush=True)
|
||||
continue
|
||||
elif error == "slow_down":
|
||||
# RFC 8628: add 5 seconds to polling interval
|
||||
server_interval = result.get("interval")
|
||||
if isinstance(server_interval, (int, float)) and server_interval > 0:
|
||||
interval = int(server_interval)
|
||||
else:
|
||||
interval += 5
|
||||
print(".", end="", flush=True)
|
||||
continue
|
||||
elif error == "expired_token":
|
||||
print()
|
||||
print(" ✗ Device code expired. Please try again.")
|
||||
return None
|
||||
elif error == "access_denied":
|
||||
print()
|
||||
print(" ✗ Authorization was denied.")
|
||||
return None
|
||||
elif error:
|
||||
print()
|
||||
print(f" ✗ Authorization failed: {error}")
|
||||
return None
|
||||
|
||||
print()
|
||||
print(" ✗ Timed out waiting for authorization.")
|
||||
return None
|
||||
|
||||
|
||||
# ─── Copilot API Headers ───────────────────────────────────────────────────
|
||||
|
||||
def copilot_request_headers(
|
||||
*,
|
||||
is_agent_turn: bool = True,
|
||||
is_vision: bool = False,
|
||||
) -> dict[str, str]:
|
||||
"""Build the standard headers for Copilot API requests.
|
||||
|
||||
Replicates the header set used by opencode and the Copilot CLI.
|
||||
"""
|
||||
headers: dict[str, str] = {
|
||||
"Editor-Version": "vscode/1.104.1",
|
||||
"User-Agent": "HermesAgent/1.0",
|
||||
"Openai-Intent": "conversation-edits",
|
||||
"x-initiator": "agent" if is_agent_turn else "user",
|
||||
}
|
||||
if is_vision:
|
||||
headers["Copilot-Vision-Request"] = "true"
|
||||
|
||||
return headers
|
||||
@@ -46,6 +46,7 @@ _PROVIDER_ENV_HINTS = (
|
||||
"KIMI_API_KEY",
|
||||
"MINIMAX_API_KEY",
|
||||
"MINIMAX_CN_API_KEY",
|
||||
"KILOCODE_API_KEY",
|
||||
)
|
||||
|
||||
|
||||
@@ -570,6 +571,8 @@ def run_doctor(args):
|
||||
# MiniMax APIs don't support /models endpoint — https://github.com/NousResearch/hermes-agent/issues/811
|
||||
("MiniMax", ("MINIMAX_API_KEY",), None, "MINIMAX_BASE_URL", False),
|
||||
("MiniMax (China)", ("MINIMAX_CN_API_KEY",), None, "MINIMAX_CN_BASE_URL", False),
|
||||
("AI Gateway", ("AI_GATEWAY_API_KEY",), "https://ai-gateway.vercel.sh/v1/models", "AI_GATEWAY_BASE_URL", True),
|
||||
("Kilo Code", ("KILOCODE_API_KEY",), "https://api.kilo.ai/api/gateway/models", "KILOCODE_BASE_URL", True),
|
||||
]
|
||||
for _pname, _env_vars, _default_url, _base_env, _supports_health_check in _apikey_providers:
|
||||
_key = ""
|
||||
@@ -714,13 +717,14 @@ def run_doctor(args):
|
||||
print(color("◆ Honcho Memory", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
try:
|
||||
from honcho_integration.client import HonchoClientConfig, GLOBAL_CONFIG_PATH
|
||||
from honcho_integration.client import HonchoClientConfig, resolve_config_path
|
||||
hcfg = HonchoClientConfig.from_global_config()
|
||||
_honcho_cfg_path = resolve_config_path()
|
||||
|
||||
if not GLOBAL_CONFIG_PATH.exists():
|
||||
if not _honcho_cfg_path.exists():
|
||||
check_warn("Honcho config not found", f"run: hermes honcho setup")
|
||||
elif not hcfg.enabled:
|
||||
check_info("Honcho disabled (set enabled: true in ~/.honcho/config.json to activate)")
|
||||
check_info(f"Honcho disabled (set enabled: true in {_honcho_cfg_path} to activate)")
|
||||
elif not hcfg.api_key:
|
||||
check_fail("Honcho API key not set", "run: hermes honcho setup")
|
||||
issues.append("No Honcho API key — run 'hermes honcho setup'")
|
||||
|
||||
+334
-32
@@ -6,6 +6,7 @@ Handles: hermes gateway [run|start|stop|restart|status|install|uninstall|setup]
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -30,6 +31,7 @@ def find_gateway_pids() -> list:
|
||||
pids = []
|
||||
patterns = [
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli/main.py gateway",
|
||||
"hermes gateway",
|
||||
"gateway/run.py",
|
||||
]
|
||||
@@ -119,17 +121,62 @@ def is_windows() -> bool:
|
||||
# Service Configuration
|
||||
# =============================================================================
|
||||
|
||||
SERVICE_NAME = "hermes-gateway"
|
||||
_SERVICE_BASE = "hermes-gateway"
|
||||
SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration"
|
||||
|
||||
|
||||
def get_service_name() -> str:
|
||||
"""Derive a systemd service name scoped to this HERMES_HOME.
|
||||
|
||||
Default ``~/.hermes`` returns ``hermes-gateway`` (backward compatible).
|
||||
Any other HERMES_HOME appends a short hash so multiple installations
|
||||
can each have their own systemd service without conflicting.
|
||||
"""
|
||||
import hashlib
|
||||
from pathlib import Path as _Path # local import to avoid monkeypatch interference
|
||||
home = _Path(os.getenv("HERMES_HOME", _Path.home() / ".hermes")).resolve()
|
||||
default = (_Path.home() / ".hermes").resolve()
|
||||
if home == default:
|
||||
return _SERVICE_BASE
|
||||
suffix = hashlib.sha256(str(home).encode()).hexdigest()[:8]
|
||||
return f"{_SERVICE_BASE}-{suffix}"
|
||||
|
||||
|
||||
SERVICE_NAME = _SERVICE_BASE # backward-compat for external importers; prefer get_service_name()
|
||||
|
||||
|
||||
def get_systemd_unit_path(system: bool = False) -> Path:
|
||||
name = get_service_name()
|
||||
if system:
|
||||
return Path("/etc/systemd/system") / f"{SERVICE_NAME}.service"
|
||||
return Path.home() / ".config" / "systemd" / "user" / f"{SERVICE_NAME}.service"
|
||||
return Path("/etc/systemd/system") / f"{name}.service"
|
||||
return Path.home() / ".config" / "systemd" / "user" / f"{name}.service"
|
||||
|
||||
|
||||
def _ensure_user_systemd_env() -> None:
|
||||
"""Ensure DBUS_SESSION_BUS_ADDRESS and XDG_RUNTIME_DIR are set for systemctl --user.
|
||||
|
||||
On headless servers (SSH sessions), these env vars may be missing even when
|
||||
the user's systemd instance is running (via linger). Without them,
|
||||
``systemctl --user`` fails with "Failed to connect to bus: No medium found".
|
||||
We detect the standard socket path and set the vars so all subsequent
|
||||
subprocess calls inherit them.
|
||||
"""
|
||||
uid = os.getuid()
|
||||
if "XDG_RUNTIME_DIR" not in os.environ:
|
||||
runtime_dir = f"/run/user/{uid}"
|
||||
if Path(runtime_dir).exists():
|
||||
os.environ["XDG_RUNTIME_DIR"] = runtime_dir
|
||||
|
||||
if "DBUS_SESSION_BUS_ADDRESS" not in os.environ:
|
||||
xdg_runtime = os.environ.get("XDG_RUNTIME_DIR", f"/run/user/{uid}")
|
||||
bus_path = Path(xdg_runtime) / "bus"
|
||||
if bus_path.exists():
|
||||
os.environ["DBUS_SESSION_BUS_ADDRESS"] = f"unix:path={bus_path}"
|
||||
|
||||
|
||||
def _systemctl_cmd(system: bool = False) -> list[str]:
|
||||
if not system:
|
||||
_ensure_user_systemd_env()
|
||||
return ["systemctl"] if system else ["systemctl", "--user"]
|
||||
|
||||
|
||||
@@ -350,17 +397,22 @@ def get_hermes_cli_path() -> str:
|
||||
# =============================================================================
|
||||
|
||||
def generate_systemd_unit(system: bool = False, run_as_user: str | None = None) -> str:
|
||||
import shutil
|
||||
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
venv_dir = str(PROJECT_ROOT / "venv")
|
||||
venv_bin = str(PROJECT_ROOT / "venv" / "bin")
|
||||
node_bin = str(PROJECT_ROOT / "node_modules" / ".bin")
|
||||
|
||||
# Build a PATH that includes the venv, node_modules, and standard system dirs
|
||||
sane_path = f"{venv_bin}:{node_bin}:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
hermes_cli = shutil.which("hermes") or f"{python_path} -m hermes_cli.main"
|
||||
path_entries = [venv_bin, node_bin]
|
||||
resolved_node = shutil.which("node")
|
||||
if resolved_node:
|
||||
resolved_node_dir = str(Path(resolved_node).resolve().parent)
|
||||
if resolved_node_dir not in path_entries:
|
||||
path_entries.append(resolved_node_dir)
|
||||
path_entries.extend(["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"])
|
||||
sane_path = ":".join(path_entries)
|
||||
|
||||
hermes_home = str(Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")).resolve())
|
||||
|
||||
if system:
|
||||
username, group_name, home_dir = _system_service_identity(run_as_user)
|
||||
@@ -368,6 +420,8 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
After=network-online.target
|
||||
Wants=network-online.target
|
||||
StartLimitIntervalSec=600
|
||||
StartLimitBurst=5
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
@@ -380,11 +434,12 @@ Environment="USER={username}"
|
||||
Environment="LOGNAME={username}"
|
||||
Environment="PATH={sane_path}"
|
||||
Environment="VIRTUAL_ENV={venv_dir}"
|
||||
Environment="HERMES_HOME={hermes_home}"
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
RestartSec=30
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
TimeoutStopSec=15
|
||||
TimeoutStopSec=60
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
|
||||
@@ -395,19 +450,21 @@ WantedBy=multi-user.target
|
||||
return f"""[Unit]
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
After=network.target
|
||||
StartLimitIntervalSec=600
|
||||
StartLimitBurst=5
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart={python_path} -m hermes_cli.main gateway run --replace
|
||||
ExecStop={hermes_cli} gateway stop
|
||||
WorkingDirectory={working_dir}
|
||||
Environment="PATH={sane_path}"
|
||||
Environment="VIRTUAL_ENV={venv_dir}"
|
||||
Environment="HERMES_HOME={hermes_home}"
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
RestartSec=30
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
TimeoutStopSec=15
|
||||
TimeoutStopSec=60
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
|
||||
@@ -455,7 +512,7 @@ def _print_linger_enable_warning(username: str, detail: str | None = None) -> No
|
||||
print(f" sudo loginctl enable-linger {username}")
|
||||
print()
|
||||
print(" Then restart the gateway:")
|
||||
print(f" systemctl --user restart {SERVICE_NAME}.service")
|
||||
print(f" systemctl --user restart {get_service_name()}.service")
|
||||
print()
|
||||
|
||||
|
||||
@@ -517,6 +574,12 @@ def systemd_install(force: bool = False, system: bool = False, run_as_user: str
|
||||
scope_flag = " --system" if system else ""
|
||||
|
||||
if unit_path.exists() and not force:
|
||||
if not systemd_unit_is_current(system=system):
|
||||
print(f"↻ Repairing outdated {_service_scope_label(system)} systemd service at: {unit_path}")
|
||||
refresh_systemd_unit_if_needed(system=system)
|
||||
subprocess.run(_systemctl_cmd(system) + ["enable", get_service_name()], check=True)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service definition updated")
|
||||
return
|
||||
print(f"Service already installed at: {unit_path}")
|
||||
print("Use --force to reinstall")
|
||||
return
|
||||
@@ -526,7 +589,7 @@ def systemd_install(force: bool = False, system: bool = False, run_as_user: str
|
||||
unit_path.write_text(generate_systemd_unit(system=system, run_as_user=run_as_user), encoding="utf-8")
|
||||
|
||||
subprocess.run(_systemctl_cmd(system) + ["daemon-reload"], check=True)
|
||||
subprocess.run(_systemctl_cmd(system) + ["enable", SERVICE_NAME], check=True)
|
||||
subprocess.run(_systemctl_cmd(system) + ["enable", get_service_name()], check=True)
|
||||
|
||||
print()
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service installed and enabled!")
|
||||
@@ -534,7 +597,7 @@ def systemd_install(force: bool = False, system: bool = False, run_as_user: str
|
||||
print("Next steps:")
|
||||
print(f" {'sudo ' if system else ''}hermes gateway start{scope_flag} # Start the service")
|
||||
print(f" {'sudo ' if system else ''}hermes gateway status{scope_flag} # Check status")
|
||||
print(f" {'journalctl' if system else 'journalctl --user'} -u {SERVICE_NAME} -f # View logs")
|
||||
print(f" {'journalctl' if system else 'journalctl --user'} -u {get_service_name()} -f # View logs")
|
||||
print()
|
||||
|
||||
if system:
|
||||
@@ -552,8 +615,8 @@ def systemd_uninstall(system: bool = False):
|
||||
if system:
|
||||
_require_root_for_system_service("uninstall")
|
||||
|
||||
subprocess.run(_systemctl_cmd(system) + ["stop", SERVICE_NAME], check=False)
|
||||
subprocess.run(_systemctl_cmd(system) + ["disable", SERVICE_NAME], check=False)
|
||||
subprocess.run(_systemctl_cmd(system) + ["stop", get_service_name()], check=False)
|
||||
subprocess.run(_systemctl_cmd(system) + ["disable", get_service_name()], check=False)
|
||||
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
if unit_path.exists():
|
||||
@@ -569,7 +632,7 @@ def systemd_start(system: bool = False):
|
||||
if system:
|
||||
_require_root_for_system_service("start")
|
||||
refresh_systemd_unit_if_needed(system=system)
|
||||
subprocess.run(_systemctl_cmd(system) + ["start", SERVICE_NAME], check=True)
|
||||
subprocess.run(_systemctl_cmd(system) + ["start", get_service_name()], check=True)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service started")
|
||||
|
||||
|
||||
@@ -578,7 +641,7 @@ def systemd_stop(system: bool = False):
|
||||
system = _select_systemd_scope(system)
|
||||
if system:
|
||||
_require_root_for_system_service("stop")
|
||||
subprocess.run(_systemctl_cmd(system) + ["stop", SERVICE_NAME], check=True)
|
||||
subprocess.run(_systemctl_cmd(system) + ["stop", get_service_name()], check=True)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service stopped")
|
||||
|
||||
|
||||
@@ -588,7 +651,7 @@ def systemd_restart(system: bool = False):
|
||||
if system:
|
||||
_require_root_for_system_service("restart")
|
||||
refresh_systemd_unit_if_needed(system=system)
|
||||
subprocess.run(_systemctl_cmd(system) + ["restart", SERVICE_NAME], check=True)
|
||||
subprocess.run(_systemctl_cmd(system) + ["restart", get_service_name()], check=True)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service restarted")
|
||||
|
||||
|
||||
@@ -613,12 +676,12 @@ def systemd_status(deep: bool = False, system: bool = False):
|
||||
print()
|
||||
|
||||
subprocess.run(
|
||||
_systemctl_cmd(system) + ["status", SERVICE_NAME, "--no-pager"],
|
||||
_systemctl_cmd(system) + ["status", get_service_name(), "--no-pager"],
|
||||
capture_output=False,
|
||||
)
|
||||
|
||||
result = subprocess.run(
|
||||
_systemctl_cmd(system) + ["is-active", SERVICE_NAME],
|
||||
_systemctl_cmd(system) + ["is-active", get_service_name()],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
@@ -657,7 +720,7 @@ def systemd_status(deep: bool = False, system: bool = False):
|
||||
if deep:
|
||||
print()
|
||||
print("Recent logs:")
|
||||
subprocess.run(_journalctl_cmd(system) + ["-u", SERVICE_NAME, "-n", "20", "--no-pager"])
|
||||
subprocess.run(_journalctl_cmd(system) + ["-u", get_service_name(), "-n", "20", "--no-pager"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -684,6 +747,7 @@ def generate_launchd_plist() -> str:
|
||||
<string>hermes_cli.main</string>
|
||||
<string>gateway</string>
|
||||
<string>run</string>
|
||||
<string>--replace</string>
|
||||
</array>
|
||||
|
||||
<key>WorkingDirectory</key>
|
||||
@@ -707,10 +771,45 @@ def generate_launchd_plist() -> str:
|
||||
</plist>
|
||||
"""
|
||||
|
||||
def launchd_plist_is_current() -> bool:
|
||||
"""Check if the installed launchd plist matches the currently generated one."""
|
||||
plist_path = get_launchd_plist_path()
|
||||
if not plist_path.exists():
|
||||
return False
|
||||
|
||||
installed = plist_path.read_text(encoding="utf-8")
|
||||
expected = generate_launchd_plist()
|
||||
return _normalize_service_definition(installed) == _normalize_service_definition(expected)
|
||||
|
||||
|
||||
def refresh_launchd_plist_if_needed() -> bool:
|
||||
"""Rewrite the installed launchd plist when the generated definition has changed.
|
||||
|
||||
Unlike systemd, launchd picks up plist changes on the next ``launchctl stop``/
|
||||
``launchctl start`` cycle — no daemon-reload is needed. We still unload/reload
|
||||
to make launchd re-read the updated plist immediately.
|
||||
"""
|
||||
plist_path = get_launchd_plist_path()
|
||||
if not plist_path.exists() or launchd_plist_is_current():
|
||||
return False
|
||||
|
||||
plist_path.write_text(generate_launchd_plist(), encoding="utf-8")
|
||||
# Unload/reload so launchd picks up the new definition
|
||||
subprocess.run(["launchctl", "unload", str(plist_path)], check=False)
|
||||
subprocess.run(["launchctl", "load", str(plist_path)], check=False)
|
||||
print("↻ Updated gateway launchd service definition to match the current Hermes install")
|
||||
return True
|
||||
|
||||
|
||||
def launchd_install(force: bool = False):
|
||||
plist_path = get_launchd_plist_path()
|
||||
|
||||
if plist_path.exists() and not force:
|
||||
if not launchd_plist_is_current():
|
||||
print(f"↻ Repairing outdated launchd service at: {plist_path}")
|
||||
refresh_launchd_plist_if_needed()
|
||||
print("✓ Service definition updated")
|
||||
return
|
||||
print(f"Service already installed at: {plist_path}")
|
||||
print("Use --force to reinstall")
|
||||
return
|
||||
@@ -739,29 +838,94 @@ def launchd_uninstall():
|
||||
print("✓ Service uninstalled")
|
||||
|
||||
def launchd_start():
|
||||
subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True)
|
||||
refresh_launchd_plist_if_needed()
|
||||
plist_path = get_launchd_plist_path()
|
||||
try:
|
||||
subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
if e.returncode != 3 or not plist_path.exists():
|
||||
raise
|
||||
print("↻ launchd job was unloaded; reloading service definition")
|
||||
subprocess.run(["launchctl", "load", str(plist_path)], check=True)
|
||||
subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True)
|
||||
print("✓ Service started")
|
||||
|
||||
def launchd_stop():
|
||||
subprocess.run(["launchctl", "stop", "ai.hermes.gateway"], check=True)
|
||||
print("✓ Service stopped")
|
||||
|
||||
def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
|
||||
"""Wait for the gateway process (by saved PID) to exit.
|
||||
|
||||
Uses the PID from the gateway.pid file — not launchd labels — so this
|
||||
works correctly when multiple gateway instances run under separate
|
||||
HERMES_HOME directories.
|
||||
|
||||
Args:
|
||||
timeout: Total seconds to wait before giving up.
|
||||
force_after: Seconds of graceful waiting before sending SIGKILL.
|
||||
"""
|
||||
import time
|
||||
from gateway.status import get_running_pid
|
||||
|
||||
deadline = time.monotonic() + timeout
|
||||
force_deadline = time.monotonic() + force_after
|
||||
force_sent = False
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
pid = get_running_pid()
|
||||
if pid is None:
|
||||
return # Process exited cleanly.
|
||||
|
||||
if not force_sent and time.monotonic() >= force_deadline:
|
||||
# Grace period expired — force-kill the specific PID.
|
||||
try:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
print(f"⚠ Gateway PID {pid} did not exit gracefully; sent SIGKILL")
|
||||
except (ProcessLookupError, PermissionError):
|
||||
return # Already gone or we can't touch it.
|
||||
force_sent = True
|
||||
|
||||
time.sleep(0.3)
|
||||
|
||||
# Timed out even after SIGKILL.
|
||||
remaining_pid = get_running_pid()
|
||||
if remaining_pid is not None:
|
||||
print(f"⚠ Gateway PID {remaining_pid} still running after {timeout}s — restart may fail")
|
||||
|
||||
|
||||
def launchd_restart():
|
||||
launchd_stop()
|
||||
try:
|
||||
launchd_stop()
|
||||
except subprocess.CalledProcessError as e:
|
||||
if e.returncode != 3:
|
||||
raise
|
||||
print("↻ launchd job was unloaded; skipping stop")
|
||||
_wait_for_gateway_exit()
|
||||
launchd_start()
|
||||
|
||||
def launchd_status(deep: bool = False):
|
||||
plist_path = get_launchd_plist_path()
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", "ai.hermes.gateway"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
print(f"Launchd plist: {plist_path}")
|
||||
if launchd_plist_is_current():
|
||||
print("✓ Service definition matches the current Hermes install")
|
||||
else:
|
||||
print("⚠ Service definition is stale relative to the current Hermes install")
|
||||
print(" Run: hermes gateway start")
|
||||
|
||||
if result.returncode == 0:
|
||||
print("✓ Gateway service is loaded")
|
||||
print(result.stdout)
|
||||
else:
|
||||
print("✗ Gateway service is not loaded")
|
||||
print(" Service definition exists locally but launchd has not loaded it.")
|
||||
print(" Run: hermes gateway start")
|
||||
|
||||
if deep:
|
||||
log_file = get_hermes_home() / "logs" / "gateway.log"
|
||||
@@ -890,6 +1054,64 @@ _PLATFORMS = [
|
||||
"help": "Paste your member ID from step 7 above."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"key": "matrix",
|
||||
"label": "Matrix",
|
||||
"emoji": "🔐",
|
||||
"token_var": "MATRIX_ACCESS_TOKEN",
|
||||
"setup_instructions": [
|
||||
"1. Works with any Matrix homeserver (self-hosted Synapse/Conduit/Dendrite or matrix.org)",
|
||||
"2. Create a bot user on your homeserver, or use your own account",
|
||||
"3. Get an access token: Element → Settings → Help & About → Access Token",
|
||||
" Or via API: curl -X POST https://your-server/_matrix/client/v3/login \\",
|
||||
" -d '{\"type\":\"m.login.password\",\"user\":\"@bot:server\",\"password\":\"...\"}'",
|
||||
"4. Alternatively, provide user ID + password and Hermes will log in directly",
|
||||
"5. For E2EE: set MATRIX_ENCRYPTION=true (requires pip install 'matrix-nio[e2e]')",
|
||||
"6. To find your user ID: it's @username:your-server (shown in Element profile)",
|
||||
],
|
||||
"vars": [
|
||||
{"name": "MATRIX_HOMESERVER", "prompt": "Homeserver URL (e.g. https://matrix.example.org)", "password": False,
|
||||
"help": "Your Matrix homeserver URL. Works with any self-hosted instance."},
|
||||
{"name": "MATRIX_ACCESS_TOKEN", "prompt": "Access token (leave empty to use password login instead)", "password": True,
|
||||
"help": "Paste your access token, or leave empty and provide user ID + password below."},
|
||||
{"name": "MATRIX_USER_ID", "prompt": "User ID (@bot:server — required for password login)", "password": False,
|
||||
"help": "Full Matrix user ID, e.g. @hermes:matrix.example.org"},
|
||||
{"name": "MATRIX_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated, e.g. @you:server)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Matrix user IDs who can interact with the bot."},
|
||||
{"name": "MATRIX_HOME_ROOM", "prompt": "Home room ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False,
|
||||
"help": "Room ID (e.g. !abc123:server) for delivering cron results and notifications."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"key": "mattermost",
|
||||
"label": "Mattermost",
|
||||
"emoji": "💬",
|
||||
"token_var": "MATTERMOST_TOKEN",
|
||||
"setup_instructions": [
|
||||
"1. In Mattermost: Integrations → Bot Accounts → Add Bot Account",
|
||||
" (System Console → Integrations → Bot Accounts must be enabled)",
|
||||
"2. Give it a username (e.g. hermes) and copy the bot token",
|
||||
"3. Works with any self-hosted Mattermost instance — enter your server URL",
|
||||
"4. To find your user ID: click your avatar (top-left) → Profile",
|
||||
" Your user ID is displayed there — click it to copy.",
|
||||
" ⚠ This is NOT your username — it's a 26-character alphanumeric ID.",
|
||||
"5. To get a channel ID: click the channel name → View Info → copy the ID",
|
||||
],
|
||||
"vars": [
|
||||
{"name": "MATTERMOST_URL", "prompt": "Server URL (e.g. https://mm.example.com)", "password": False,
|
||||
"help": "Your Mattermost server URL. Works with any self-hosted instance."},
|
||||
{"name": "MATTERMOST_TOKEN", "prompt": "Bot token", "password": True,
|
||||
"help": "Paste the bot token from step 2 above."},
|
||||
{"name": "MATTERMOST_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Your Mattermost user ID from step 4 above."},
|
||||
{"name": "MATTERMOST_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False,
|
||||
"help": "Channel ID where Hermes delivers cron results and notifications."},
|
||||
{"name": "MATTERMOST_REPLY_MODE", "prompt": "Reply mode — 'off' for flat messages, 'thread' for threaded replies (default: off)", "password": False,
|
||||
"help": "off = flat channel messages, thread = replies nest under your message."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"key": "whatsapp",
|
||||
"label": "WhatsApp",
|
||||
@@ -928,6 +1150,51 @@ _PLATFORMS = [
|
||||
"help": "Only emails from these addresses will be processed."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"key": "sms",
|
||||
"label": "SMS (Twilio)",
|
||||
"emoji": "📱",
|
||||
"token_var": "TWILIO_ACCOUNT_SID",
|
||||
"setup_instructions": [
|
||||
"1. Create a Twilio account at https://www.twilio.com/",
|
||||
"2. Get your Account SID and Auth Token from the Twilio Console dashboard",
|
||||
"3. Buy or configure a phone number capable of sending SMS",
|
||||
"4. Set up your webhook URL for inbound SMS:",
|
||||
" Twilio Console → Phone Numbers → Active Numbers → your number",
|
||||
" → Messaging → A MESSAGE COMES IN → Webhook → https://your-server:8080/webhooks/twilio",
|
||||
],
|
||||
"vars": [
|
||||
{"name": "TWILIO_ACCOUNT_SID", "prompt": "Twilio Account SID", "password": False,
|
||||
"help": "Found on the Twilio Console dashboard."},
|
||||
{"name": "TWILIO_AUTH_TOKEN", "prompt": "Twilio Auth Token", "password": True,
|
||||
"help": "Found on the Twilio Console dashboard (click to reveal)."},
|
||||
{"name": "TWILIO_PHONE_NUMBER", "prompt": "Twilio phone number (E.164 format, e.g. +15551234567)", "password": False,
|
||||
"help": "The Twilio phone number to send SMS from."},
|
||||
{"name": "SMS_ALLOWED_USERS", "prompt": "Allowed phone numbers (comma-separated, E.164 format)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Only messages from these phone numbers will be processed."},
|
||||
{"name": "SMS_HOME_CHANNEL", "prompt": "Home channel phone number (for cron/notification delivery, or empty)", "password": False,
|
||||
"help": "Phone number to deliver cron job results and notifications to."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"key": "dingtalk",
|
||||
"label": "DingTalk",
|
||||
"emoji": "💬",
|
||||
"token_var": "DINGTALK_CLIENT_ID",
|
||||
"setup_instructions": [
|
||||
"1. Go to https://open-dev.dingtalk.com → Create Application",
|
||||
"2. Under 'Credentials', copy the AppKey (Client ID) and AppSecret (Client Secret)",
|
||||
"3. Enable 'Stream Mode' under the bot settings",
|
||||
"4. Add the bot to a group chat or message it directly",
|
||||
],
|
||||
"vars": [
|
||||
{"name": "DINGTALK_CLIENT_ID", "prompt": "AppKey (Client ID)", "password": False,
|
||||
"help": "The AppKey from your DingTalk application credentials."},
|
||||
{"name": "DINGTALK_CLIENT_SECRET", "prompt": "AppSecret (Client Secret)", "password": True,
|
||||
"help": "The AppSecret from your DingTalk application credentials."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -962,6 +1229,16 @@ def _platform_status(platform: dict) -> str:
|
||||
if any([val, pwd, imap, smtp]):
|
||||
return "partially configured"
|
||||
return "not configured"
|
||||
if platform.get("key") == "matrix":
|
||||
homeserver = get_env_value("MATRIX_HOMESERVER")
|
||||
password = get_env_value("MATRIX_PASSWORD")
|
||||
if (val or password) and homeserver:
|
||||
e2ee = get_env_value("MATRIX_ENCRYPTION")
|
||||
suffix = " + E2EE" if e2ee and e2ee.lower() in ("true", "1", "yes") else ""
|
||||
return f"configured{suffix}"
|
||||
if val or password or homeserver:
|
||||
return "partially configured"
|
||||
return "not configured"
|
||||
if val:
|
||||
return "configured"
|
||||
return "not configured"
|
||||
@@ -1118,7 +1395,7 @@ def _is_service_running() -> bool:
|
||||
|
||||
if user_unit_exists:
|
||||
result = subprocess.run(
|
||||
_systemctl_cmd(False) + ["is-active", SERVICE_NAME],
|
||||
_systemctl_cmd(False) + ["is-active", get_service_name()],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.stdout.strip() == "active":
|
||||
@@ -1126,7 +1403,7 @@ def _is_service_running() -> bool:
|
||||
|
||||
if system_unit_exists:
|
||||
result = subprocess.run(
|
||||
_systemctl_cmd(True) + ["is-active", SERVICE_NAME],
|
||||
_systemctl_cmd(True) + ["is-active", get_service_name()],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.stdout.strip() == "active":
|
||||
@@ -1477,14 +1754,17 @@ def gateway_command(args):
|
||||
# Try service first, fall back to killing and restarting
|
||||
service_available = False
|
||||
system = getattr(args, 'system', False)
|
||||
service_configured = False
|
||||
|
||||
if is_linux() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
service_configured = True
|
||||
try:
|
||||
systemd_restart(system=system)
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
service_configured = True
|
||||
try:
|
||||
launchd_restart()
|
||||
service_available = True
|
||||
@@ -1492,14 +1772,36 @@ def gateway_command(args):
|
||||
pass
|
||||
|
||||
if not service_available:
|
||||
# systemd/launchd restart failed — check if linger is the issue
|
||||
if is_linux():
|
||||
linger_ok, _detail = get_systemd_linger_status()
|
||||
if linger_ok is not True:
|
||||
import getpass
|
||||
_username = getpass.getuser()
|
||||
print()
|
||||
print("⚠ Cannot restart gateway as a service — linger is not enabled.")
|
||||
print(" The gateway user service requires linger to function on headless servers.")
|
||||
print()
|
||||
print(f" Run: sudo loginctl enable-linger {_username}")
|
||||
print()
|
||||
print(" Then restart the gateway:")
|
||||
print(" hermes gateway restart")
|
||||
return
|
||||
|
||||
if service_configured:
|
||||
print()
|
||||
print("✗ Gateway service restart failed.")
|
||||
print(" The service definition exists, but the service manager did not recover it.")
|
||||
print(" Fix the service, then retry: hermes gateway start")
|
||||
sys.exit(1)
|
||||
|
||||
# Manual restart: kill existing processes
|
||||
killed = kill_gateway_processes()
|
||||
if killed:
|
||||
print(f"✓ Stopped {killed} gateway process(es)")
|
||||
|
||||
import time
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
_wait_for_gateway_exit(timeout=10.0, force_after=5.0)
|
||||
|
||||
# Start fresh
|
||||
print("Starting gateway...")
|
||||
run_gateway(verbose=False)
|
||||
|
||||
+845
-61
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,635 @@
|
||||
"""
|
||||
MCP Server Management CLI — ``hermes mcp`` subcommand.
|
||||
|
||||
Implements ``hermes mcp add/remove/list/test/configure`` for interactive
|
||||
MCP server lifecycle management (issue #690 Phase 2).
|
||||
|
||||
Relies on tools/mcp_tool.py for connection/discovery and keeps
|
||||
configuration in ~/.hermes/config.yaml under the ``mcp_servers`` key.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import getpass
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from hermes_cli.config import (
|
||||
load_config,
|
||||
save_config,
|
||||
get_env_value,
|
||||
save_env_value,
|
||||
get_hermes_home,
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─── UI Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def _info(text: str):
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
def _success(text: str):
|
||||
print(color(f" ✓ {text}", Colors.GREEN))
|
||||
|
||||
def _warning(text: str):
|
||||
print(color(f" ⚠ {text}", Colors.YELLOW))
|
||||
|
||||
def _error(text: str):
|
||||
print(color(f" ✗ {text}", Colors.RED))
|
||||
|
||||
|
||||
def _confirm(question: str, default: bool = True) -> bool:
|
||||
default_str = "Y/n" if default else "y/N"
|
||||
try:
|
||||
val = input(color(f" {question} [{default_str}]: ", Colors.YELLOW)).strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
if not val:
|
||||
return default
|
||||
return val in ("y", "yes")
|
||||
|
||||
|
||||
def _prompt(question: str, *, password: bool = False, default: str = "") -> str:
|
||||
display = f" {question}"
|
||||
if default:
|
||||
display += f" [{default}]"
|
||||
display += ": "
|
||||
try:
|
||||
if password:
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
return value.strip() or default
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
|
||||
|
||||
# ─── Config Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _get_mcp_servers(config: Optional[dict] = None) -> Dict[str, dict]:
|
||||
"""Return the ``mcp_servers`` dict from config, or empty dict."""
|
||||
if config is None:
|
||||
config = load_config()
|
||||
servers = config.get("mcp_servers")
|
||||
if not servers or not isinstance(servers, dict):
|
||||
return {}
|
||||
return servers
|
||||
|
||||
|
||||
def _save_mcp_server(name: str, server_config: dict):
|
||||
"""Add or update a server entry in config.yaml."""
|
||||
config = load_config()
|
||||
config.setdefault("mcp_servers", {})[name] = server_config
|
||||
save_config(config)
|
||||
|
||||
|
||||
def _remove_mcp_server(name: str) -> bool:
|
||||
"""Remove a server from config.yaml. Returns True if it existed."""
|
||||
config = load_config()
|
||||
servers = config.get("mcp_servers", {})
|
||||
if name not in servers:
|
||||
return False
|
||||
del servers[name]
|
||||
if not servers:
|
||||
config.pop("mcp_servers", None)
|
||||
save_config(config)
|
||||
return True
|
||||
|
||||
|
||||
def _env_key_for_server(name: str) -> str:
|
||||
"""Convert server name to an env-var key like ``MCP_MYSERVER_API_KEY``."""
|
||||
return f"MCP_{name.upper().replace('-', '_')}_API_KEY"
|
||||
|
||||
|
||||
# ─── Discovery (temporary connect) ───────────────────────────────────────────
|
||||
|
||||
def _probe_single_server(
|
||||
name: str, config: dict, connect_timeout: float = 30
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Temporarily connect to one MCP server, list its tools, disconnect.
|
||||
|
||||
Returns list of ``(tool_name, description)`` tuples.
|
||||
Raises on connection failure.
|
||||
"""
|
||||
from tools.mcp_tool import (
|
||||
_ensure_mcp_loop,
|
||||
_run_on_mcp_loop,
|
||||
_connect_server,
|
||||
_stop_mcp_loop,
|
||||
)
|
||||
|
||||
_ensure_mcp_loop()
|
||||
|
||||
tools_found: List[Tuple[str, str]] = []
|
||||
|
||||
async def _probe():
|
||||
server = await asyncio.wait_for(
|
||||
_connect_server(name, config), timeout=connect_timeout
|
||||
)
|
||||
for t in server._tools:
|
||||
desc = getattr(t, "description", "") or ""
|
||||
# Truncate long descriptions for display
|
||||
if len(desc) > 80:
|
||||
desc = desc[:77] + "..."
|
||||
tools_found.append((t.name, desc))
|
||||
await server.shutdown()
|
||||
|
||||
try:
|
||||
_run_on_mcp_loop(_probe(), timeout=connect_timeout + 10)
|
||||
except BaseException as exc:
|
||||
raise _unwrap_exception_group(exc) from None
|
||||
finally:
|
||||
_stop_mcp_loop()
|
||||
|
||||
return tools_found
|
||||
|
||||
|
||||
def _unwrap_exception_group(exc: BaseException) -> Exception:
|
||||
"""Extract the root-cause exception from anyio TaskGroup wrappers.
|
||||
|
||||
The MCP SDK uses anyio task groups, which wrap errors in
|
||||
``BaseExceptionGroup`` / ``ExceptionGroup``. This makes error
|
||||
messages opaque ("unhandled errors in a TaskGroup"). We unwrap
|
||||
to surface the real cause (e.g. "401 Unauthorized").
|
||||
"""
|
||||
while isinstance(exc, BaseExceptionGroup) and exc.exceptions:
|
||||
exc = exc.exceptions[0]
|
||||
# Return a plain Exception so callers can catch normally
|
||||
if isinstance(exc, Exception):
|
||||
return exc
|
||||
return RuntimeError(str(exc))
|
||||
|
||||
|
||||
# ─── hermes mcp add ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_add(args):
|
||||
"""Add a new MCP server with discovery-first tool selection."""
|
||||
name = args.name
|
||||
url = getattr(args, "url", None)
|
||||
command = getattr(args, "command", None)
|
||||
cmd_args = getattr(args, "args", None) or []
|
||||
auth_type = getattr(args, "auth", None)
|
||||
|
||||
# Validate transport
|
||||
if not url and not command:
|
||||
_error("Must specify --url <endpoint> or --command <cmd>")
|
||||
_info("Examples:")
|
||||
_info(' hermes mcp add ink --url "https://mcp.ml.ink/mcp"')
|
||||
_info(' hermes mcp add github --command npx --args @modelcontextprotocol/server-github')
|
||||
return
|
||||
|
||||
# Check if server already exists
|
||||
existing = _get_mcp_servers()
|
||||
if name in existing:
|
||||
if not _confirm(f"Server '{name}' already exists. Overwrite?", default=False):
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
# Build initial config
|
||||
server_config: Dict[str, Any] = {}
|
||||
if url:
|
||||
server_config["url"] = url
|
||||
else:
|
||||
server_config["command"] = command
|
||||
if cmd_args:
|
||||
server_config["args"] = cmd_args
|
||||
|
||||
# ── Authentication ────────────────────────────────────────────────
|
||||
|
||||
if url and auth_type == "oauth":
|
||||
print()
|
||||
_info(f"Starting OAuth flow for '{name}'...")
|
||||
oauth_ok = False
|
||||
try:
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
oauth_auth = build_oauth_auth(name, url)
|
||||
if oauth_auth:
|
||||
server_config["auth"] = "oauth"
|
||||
_success("OAuth configured (tokens will be acquired on first connection)")
|
||||
oauth_ok=True
|
||||
else:
|
||||
_warning("OAuth setup failed — MCP SDK auth module not available")
|
||||
except Exception as exc:
|
||||
_warning(f"OAuth error: {exc}")
|
||||
|
||||
if not oauth_ok:
|
||||
_info("This server may not support OAuth.")
|
||||
if _confirm("Continue without authentication?", default=True):
|
||||
# Don't store auth: oauth — server doesn't support it
|
||||
pass
|
||||
else:
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
elif url:
|
||||
# Prompt for API key / Bearer token for HTTP servers
|
||||
print()
|
||||
_info(f"Connecting to {url}")
|
||||
needs_auth = _confirm("Does this server require authentication?", default=True)
|
||||
if needs_auth:
|
||||
if auth_type == "header" or not auth_type:
|
||||
env_key = _env_key_for_server(name)
|
||||
existing_key = get_env_value(env_key)
|
||||
if existing_key:
|
||||
_success(f"{env_key}: already configured")
|
||||
api_key = existing_key
|
||||
else:
|
||||
api_key = _prompt("API key / Bearer token", password=True)
|
||||
if api_key:
|
||||
save_env_value(env_key, api_key)
|
||||
_success(f"Saved to ~/.hermes/.env as {env_key}")
|
||||
|
||||
# Set header with env var interpolation
|
||||
if api_key or existing_key:
|
||||
server_config["headers"] = {
|
||||
"Authorization": f"Bearer ${{{env_key}}}"
|
||||
}
|
||||
|
||||
# ── Discovery: connect and list tools ─────────────────────────────
|
||||
|
||||
print()
|
||||
print(color(f" Connecting to '{name}'...", Colors.CYAN))
|
||||
|
||||
try:
|
||||
tools = _probe_single_server(name, server_config)
|
||||
except Exception as exc:
|
||||
_error(f"Failed to connect: {exc}")
|
||||
if _confirm("Save config anyway (you can test later)?", default=False):
|
||||
server_config["enabled"] = False
|
||||
_save_mcp_server(name, server_config)
|
||||
_success(f"Saved '{name}' to config (disabled)")
|
||||
_info("Fix the issue, then: hermes mcp test " + name)
|
||||
return
|
||||
|
||||
if not tools:
|
||||
_warning("Server connected but reported no tools.")
|
||||
if _confirm("Save config anyway?", default=True):
|
||||
_save_mcp_server(name, server_config)
|
||||
_success(f"Saved '{name}' to config")
|
||||
return
|
||||
|
||||
# ── Tool selection ────────────────────────────────────────────────
|
||||
|
||||
print()
|
||||
_success(f"Connected! Found {len(tools)} tool(s) from '{name}':")
|
||||
print()
|
||||
for tool_name, desc in tools:
|
||||
short = desc[:60] + "..." if len(desc) > 60 else desc
|
||||
print(f" {color(tool_name, Colors.GREEN):40s} {short}")
|
||||
print()
|
||||
|
||||
# Ask: enable all, select, or cancel
|
||||
try:
|
||||
choice = input(
|
||||
color(f" Enable all {len(tools)} tools? [Y/n/select]: ", Colors.YELLOW)
|
||||
).strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
if choice in ("n", "no"):
|
||||
_info("Cancelled — server not saved.")
|
||||
return
|
||||
|
||||
if choice in ("s", "select"):
|
||||
# Interactive tool selection
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
labels = [f"{t[0]} — {t[1]}" for t in tools]
|
||||
pre_selected = set(range(len(tools)))
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"Select tools for '{name}'",
|
||||
labels,
|
||||
pre_selected,
|
||||
)
|
||||
|
||||
if not chosen:
|
||||
_info("No tools selected — server not saved.")
|
||||
return
|
||||
|
||||
chosen_names = [tools[i][0] for i in sorted(chosen)]
|
||||
server_config.setdefault("tools", {})["include"] = chosen_names
|
||||
|
||||
tool_count = len(chosen_names)
|
||||
total = len(tools)
|
||||
else:
|
||||
# Enable all (no filter needed — default behaviour)
|
||||
tool_count = len(tools)
|
||||
total = len(tools)
|
||||
|
||||
# ── Save ──────────────────────────────────────────────────────────
|
||||
|
||||
server_config["enabled"] = True
|
||||
_save_mcp_server(name, server_config)
|
||||
|
||||
print()
|
||||
_success(f"Saved '{name}' to ~/.hermes/config.yaml ({tool_count}/{total} tools enabled)")
|
||||
_info("Start a new session to use these tools.")
|
||||
|
||||
|
||||
# ─── hermes mcp remove ───────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_remove(args):
|
||||
"""Remove an MCP server from config."""
|
||||
name = args.name
|
||||
existing = _get_mcp_servers()
|
||||
|
||||
if name not in existing:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
servers = list(existing.keys())
|
||||
if servers:
|
||||
_info(f"Available servers: {', '.join(servers)}")
|
||||
return
|
||||
|
||||
if not _confirm(f"Remove server '{name}'?", default=True):
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
_remove_mcp_server(name)
|
||||
_success(f"Removed '{name}' from config")
|
||||
|
||||
# Clean up OAuth tokens if they exist
|
||||
try:
|
||||
from tools.mcp_oauth import remove_oauth_tokens
|
||||
remove_oauth_tokens(name)
|
||||
_success("Cleaned up OAuth tokens")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ─── hermes mcp list ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_list(args=None):
|
||||
"""List all configured MCP servers."""
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if not servers:
|
||||
print()
|
||||
_info("No MCP servers configured.")
|
||||
print()
|
||||
_info("Add one with:")
|
||||
_info(' hermes mcp add <name> --url <endpoint>')
|
||||
_info(' hermes mcp add <name> --command <cmd> --args <args...>')
|
||||
print()
|
||||
return
|
||||
|
||||
print()
|
||||
print(color(" MCP Servers:", Colors.CYAN + Colors.BOLD))
|
||||
print()
|
||||
|
||||
# Table header
|
||||
print(f" {'Name':<16} {'Transport':<30} {'Tools':<12} {'Status':<10}")
|
||||
print(f" {'─' * 16} {'─' * 30} {'─' * 12} {'─' * 10}")
|
||||
|
||||
for name, cfg in servers.items():
|
||||
# Transport info
|
||||
if "url" in cfg:
|
||||
url = cfg["url"]
|
||||
# Truncate long URLs
|
||||
if len(url) > 28:
|
||||
url = url[:25] + "..."
|
||||
transport = url
|
||||
elif "command" in cfg:
|
||||
cmd = cfg["command"]
|
||||
cmd_args = cfg.get("args", [])
|
||||
if isinstance(cmd_args, list) and cmd_args:
|
||||
transport = f"{cmd} {' '.join(str(a) for a in cmd_args[:2])}"
|
||||
else:
|
||||
transport = cmd
|
||||
if len(transport) > 28:
|
||||
transport = transport[:25] + "..."
|
||||
else:
|
||||
transport = "?"
|
||||
|
||||
# Tool count
|
||||
tools_cfg = cfg.get("tools", {})
|
||||
if isinstance(tools_cfg, dict):
|
||||
include = tools_cfg.get("include")
|
||||
exclude = tools_cfg.get("exclude")
|
||||
if include and isinstance(include, list):
|
||||
tools_str = f"{len(include)} selected"
|
||||
elif exclude and isinstance(exclude, list):
|
||||
tools_str = f"-{len(exclude)} excluded"
|
||||
else:
|
||||
tools_str = "all"
|
||||
else:
|
||||
tools_str = "all"
|
||||
|
||||
# Enabled status
|
||||
enabled = cfg.get("enabled", True)
|
||||
if isinstance(enabled, str):
|
||||
enabled = enabled.lower() in ("true", "1", "yes")
|
||||
status = color("✓ enabled", Colors.GREEN) if enabled else color("✗ disabled", Colors.DIM)
|
||||
|
||||
print(f" {name:<16} {transport:<30} {tools_str:<12} {status}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# ─── hermes mcp test ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_test(args):
|
||||
"""Test connection to an MCP server."""
|
||||
name = args.name
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if name not in servers:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
available = list(servers.keys())
|
||||
if available:
|
||||
_info(f"Available: {', '.join(available)}")
|
||||
return
|
||||
|
||||
cfg = servers[name]
|
||||
print()
|
||||
print(color(f" Testing '{name}'...", Colors.CYAN))
|
||||
|
||||
# Show transport info
|
||||
if "url" in cfg:
|
||||
_info(f"Transport: HTTP → {cfg['url']}")
|
||||
else:
|
||||
cmd = cfg.get("command", "?")
|
||||
_info(f"Transport: stdio → {cmd}")
|
||||
|
||||
# Show auth info (masked)
|
||||
auth_type = cfg.get("auth", "")
|
||||
headers = cfg.get("headers", {})
|
||||
if auth_type == "oauth":
|
||||
_info("Auth: OAuth 2.1 PKCE")
|
||||
elif headers:
|
||||
for k, v in headers.items():
|
||||
if isinstance(v, str) and ("key" in k.lower() or "auth" in k.lower()):
|
||||
# Mask the value
|
||||
resolved = _interpolate_value(v)
|
||||
if len(resolved) > 8:
|
||||
masked = resolved[:4] + "***" + resolved[-4:]
|
||||
else:
|
||||
masked = "***"
|
||||
print(f" {k}: {masked}")
|
||||
else:
|
||||
_info("Auth: none")
|
||||
|
||||
# Attempt connection
|
||||
start = time.monotonic()
|
||||
try:
|
||||
tools = _probe_single_server(name, cfg)
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
_error(f"Connection failed ({elapsed_ms:.0f}ms): {exc}")
|
||||
return
|
||||
|
||||
_success(f"Connected ({elapsed_ms:.0f}ms)")
|
||||
_success(f"Tools discovered: {len(tools)}")
|
||||
|
||||
if tools:
|
||||
print()
|
||||
for tool_name, desc in tools:
|
||||
short = desc[:55] + "..." if len(desc) > 55 else desc
|
||||
print(f" {color(tool_name, Colors.GREEN):36s} {short}")
|
||||
print()
|
||||
|
||||
|
||||
def _interpolate_value(value: str) -> str:
|
||||
"""Resolve ``${ENV_VAR}`` references in a string."""
|
||||
def _replace(m):
|
||||
return os.getenv(m.group(1), "")
|
||||
return re.sub(r"\$\{(\w+)\}", _replace, value)
|
||||
|
||||
|
||||
# ─── hermes mcp configure ────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_configure(args):
|
||||
"""Reconfigure which tools are enabled for an existing MCP server."""
|
||||
name = args.name
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if name not in servers:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
available = list(servers.keys())
|
||||
if available:
|
||||
_info(f"Available: {', '.join(available)}")
|
||||
return
|
||||
|
||||
cfg = servers[name]
|
||||
|
||||
# Discover all available tools
|
||||
print()
|
||||
print(color(f" Connecting to '{name}' to discover tools...", Colors.CYAN))
|
||||
|
||||
try:
|
||||
all_tools = _probe_single_server(name, cfg)
|
||||
except Exception as exc:
|
||||
_error(f"Failed to connect: {exc}")
|
||||
return
|
||||
|
||||
if not all_tools:
|
||||
_warning("Server reports no tools.")
|
||||
return
|
||||
|
||||
# Determine which are currently enabled
|
||||
tools_cfg = cfg.get("tools", {})
|
||||
if isinstance(tools_cfg, dict):
|
||||
include = tools_cfg.get("include")
|
||||
exclude = tools_cfg.get("exclude")
|
||||
else:
|
||||
include = None
|
||||
exclude = None
|
||||
|
||||
tool_names = [t[0] for t in all_tools]
|
||||
|
||||
if include and isinstance(include, list):
|
||||
include_set = set(include)
|
||||
pre_selected = {
|
||||
i for i, tn in enumerate(tool_names) if tn in include_set
|
||||
}
|
||||
elif exclude and isinstance(exclude, list):
|
||||
exclude_set = set(exclude)
|
||||
pre_selected = {
|
||||
i for i, tn in enumerate(tool_names) if tn not in exclude_set
|
||||
}
|
||||
else:
|
||||
pre_selected = set(range(len(all_tools)))
|
||||
|
||||
currently = len(pre_selected)
|
||||
total = len(all_tools)
|
||||
_info(f"Currently {currently}/{total} tools enabled for '{name}'.")
|
||||
print()
|
||||
|
||||
# Interactive checklist
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
labels = [f"{t[0]} — {t[1]}" for t in all_tools]
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"Select tools for '{name}'",
|
||||
labels,
|
||||
pre_selected,
|
||||
)
|
||||
|
||||
if chosen == pre_selected:
|
||||
_info("No changes made.")
|
||||
return
|
||||
|
||||
# Update config
|
||||
config = load_config()
|
||||
server_entry = config.get("mcp_servers", {}).get(name, {})
|
||||
|
||||
if len(chosen) == total:
|
||||
# All selected → remove include/exclude (register all)
|
||||
server_entry.pop("tools", None)
|
||||
else:
|
||||
chosen_names = [tool_names[i] for i in sorted(chosen)]
|
||||
server_entry.setdefault("tools", {})
|
||||
server_entry["tools"]["include"] = chosen_names
|
||||
server_entry["tools"].pop("exclude", None)
|
||||
|
||||
config.setdefault("mcp_servers", {})[name] = server_entry
|
||||
save_config(config)
|
||||
|
||||
new_count = len(chosen)
|
||||
_success(f"Updated config: {new_count}/{total} tools enabled")
|
||||
_info("Start a new session for changes to take effect.")
|
||||
|
||||
|
||||
# ─── Dispatcher ───────────────────────────────────────────────────────────────
|
||||
|
||||
def mcp_command(args):
|
||||
"""Main dispatcher for ``hermes mcp`` subcommands."""
|
||||
action = getattr(args, "mcp_action", None)
|
||||
|
||||
handlers = {
|
||||
"add": cmd_mcp_add,
|
||||
"remove": cmd_mcp_remove,
|
||||
"rm": cmd_mcp_remove,
|
||||
"list": cmd_mcp_list,
|
||||
"ls": cmd_mcp_list,
|
||||
"test": cmd_mcp_test,
|
||||
"configure": cmd_mcp_configure,
|
||||
"config": cmd_mcp_configure,
|
||||
}
|
||||
|
||||
handler = handlers.get(action)
|
||||
if handler:
|
||||
handler(args)
|
||||
else:
|
||||
# No subcommand — show list
|
||||
cmd_mcp_list()
|
||||
print(color(" Commands:", Colors.CYAN))
|
||||
_info("hermes mcp add <name> --url <endpoint> Add an MCP server")
|
||||
_info("hermes mcp add <name> --command <cmd> Add a stdio server")
|
||||
_info("hermes mcp remove <name> Remove a server")
|
||||
_info("hermes mcp list List servers")
|
||||
_info("hermes mcp test <name> Test connection")
|
||||
_info("hermes mcp configure <name> Toggle tools")
|
||||
print()
|
||||
+782
-28
@@ -8,26 +8,47 @@ Add, remove, or reorder entries here — both `hermes setup` and
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from difflib import get_close_matches
|
||||
from typing import Any, Optional
|
||||
|
||||
COPILOT_BASE_URL = "https://api.githubcopilot.com"
|
||||
COPILOT_MODELS_URL = f"{COPILOT_BASE_URL}/models"
|
||||
COPILOT_EDITOR_VERSION = "vscode/1.104.1"
|
||||
COPILOT_REASONING_EFFORTS_GPT5 = ["minimal", "low", "medium", "high"]
|
||||
COPILOT_REASONING_EFFORTS_O_SERIES = ["low", "medium", "high"]
|
||||
|
||||
# Backward-compatible aliases for the earlier GitHub Models-backed Copilot work.
|
||||
GITHUB_MODELS_BASE_URL = COPILOT_BASE_URL
|
||||
GITHUB_MODELS_CATALOG_URL = COPILOT_MODELS_URL
|
||||
|
||||
# (model_id, display description shown in menus)
|
||||
OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-opus-4.6", "recommended"),
|
||||
("anthropic/claude-sonnet-4.5", ""),
|
||||
("openai/gpt-5.4-pro", ""),
|
||||
("anthropic/claude-haiku-4.5", ""),
|
||||
("openai/gpt-5.4", ""),
|
||||
("openai/gpt-5.4-mini", ""),
|
||||
("xiaomi/mimo-v2-pro", ""),
|
||||
("openai/gpt-5.3-codex", ""),
|
||||
("google/gemini-3-pro-preview", ""),
|
||||
("google/gemini-3-flash-preview", ""),
|
||||
("qwen/qwen3.5-plus-02-15", ""),
|
||||
("qwen/qwen3.5-35b-a3b", ""),
|
||||
("stepfun/step-3.5-flash", ""),
|
||||
("z-ai/glm-5", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("minimax/minimax-m2.7", ""),
|
||||
("minimax/minimax-m2.5", ""),
|
||||
("z-ai/glm-5", ""),
|
||||
("z-ai/glm-5-turbo", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("x-ai/grok-4.20-beta", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b:free", "free"),
|
||||
("arcee-ai/trinity-large-preview:free", "free"),
|
||||
("openai/gpt-5.4-pro", ""),
|
||||
("openai/gpt-5.4-nano", ""),
|
||||
]
|
||||
|
||||
_PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
@@ -45,6 +66,25 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"gpt-5.1-codex-mini",
|
||||
"gpt-5.1-codex-max",
|
||||
],
|
||||
"copilot-acp": [
|
||||
"copilot-acp",
|
||||
],
|
||||
"copilot": [
|
||||
"gpt-5.4",
|
||||
"gpt-5.4-mini",
|
||||
"gpt-5-mini",
|
||||
"gpt-5.3-codex",
|
||||
"gpt-5.2-codex",
|
||||
"gpt-4.1",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"claude-opus-4.6",
|
||||
"claude-sonnet-4.6",
|
||||
"claude-sonnet-4.5",
|
||||
"claude-haiku-4.5",
|
||||
"gemini-2.5-pro",
|
||||
"grok-code-fast-1",
|
||||
],
|
||||
"zai": [
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
@@ -60,11 +100,15 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"kimi-k2-0905-preview",
|
||||
],
|
||||
"minimax": [
|
||||
"MiniMax-M2.7",
|
||||
"MiniMax-M2.7-highspeed",
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
"MiniMax-M2.1",
|
||||
],
|
||||
"minimax-cn": [
|
||||
"MiniMax-M2.7",
|
||||
"MiniMax-M2.7-highspeed",
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
"MiniMax-M2.1",
|
||||
@@ -78,17 +122,102 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-haiku-4-5-20251001",
|
||||
],
|
||||
"deepseek": [
|
||||
"deepseek-chat",
|
||||
"deepseek-reasoner",
|
||||
],
|
||||
"opencode-zen": [
|
||||
"gpt-5.4-pro",
|
||||
"gpt-5.4",
|
||||
"gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark",
|
||||
"gpt-5.2",
|
||||
"gpt-5.2-codex",
|
||||
"gpt-5.1",
|
||||
"gpt-5.1-codex",
|
||||
"gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-mini",
|
||||
"gpt-5",
|
||||
"gpt-5-codex",
|
||||
"gpt-5-nano",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-5",
|
||||
"claude-opus-4-1",
|
||||
"claude-sonnet-4-6",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4",
|
||||
"claude-haiku-4-5",
|
||||
"claude-3-5-haiku",
|
||||
"gemini-3.1-pro",
|
||||
"gemini-3-pro",
|
||||
"gemini-3-flash",
|
||||
"minimax-m2.7",
|
||||
"minimax-m2.5",
|
||||
"minimax-m2.5-free",
|
||||
"minimax-m2.1",
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
"glm-4.6",
|
||||
"kimi-k2.5",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2",
|
||||
"qwen3-coder",
|
||||
"big-pickle",
|
||||
],
|
||||
"opencode-go": [
|
||||
"glm-5",
|
||||
"kimi-k2.5",
|
||||
"minimax-m2.5",
|
||||
],
|
||||
"ai-gateway": [
|
||||
"anthropic/claude-opus-4.6",
|
||||
"anthropic/claude-sonnet-4.6",
|
||||
"anthropic/claude-sonnet-4.5",
|
||||
"anthropic/claude-haiku-4.5",
|
||||
"openai/gpt-5",
|
||||
"openai/gpt-4.1",
|
||||
"openai/gpt-4.1-mini",
|
||||
"google/gemini-3-pro-preview",
|
||||
"google/gemini-3-flash",
|
||||
"google/gemini-2.5-pro",
|
||||
"google/gemini-2.5-flash",
|
||||
"deepseek/deepseek-v3.2",
|
||||
],
|
||||
"kilocode": [
|
||||
"anthropic/claude-opus-4.6",
|
||||
"anthropic/claude-sonnet-4.6",
|
||||
"openai/gpt-5.4",
|
||||
"google/gemini-3-pro-preview",
|
||||
"google/gemini-3-flash-preview",
|
||||
],
|
||||
"alibaba": [
|
||||
"qwen3.5-plus",
|
||||
"qwen3-max",
|
||||
"qwen3-coder-plus",
|
||||
"qwen3-coder-next",
|
||||
"qwen-plus-latest",
|
||||
"qwen3.5-flash",
|
||||
"qwen-vl-max",
|
||||
],
|
||||
}
|
||||
|
||||
_PROVIDER_LABELS = {
|
||||
"openrouter": "OpenRouter",
|
||||
"openai-codex": "OpenAI Codex",
|
||||
"copilot-acp": "GitHub Copilot ACP",
|
||||
"nous": "Nous Portal",
|
||||
"copilot": "GitHub Copilot",
|
||||
"zai": "Z.AI / GLM",
|
||||
"kimi-coding": "Kimi / Moonshot",
|
||||
"minimax": "MiniMax",
|
||||
"minimax-cn": "MiniMax (China)",
|
||||
"anthropic": "Anthropic",
|
||||
"deepseek": "DeepSeek",
|
||||
"opencode-zen": "OpenCode Zen",
|
||||
"opencode-go": "OpenCode Go",
|
||||
"ai-gateway": "AI Gateway",
|
||||
"kilocode": "Kilo Code",
|
||||
"alibaba": "Alibaba Cloud (DashScope)",
|
||||
"custom": "Custom endpoint",
|
||||
}
|
||||
|
||||
@@ -97,12 +226,33 @@ _PROVIDER_ALIASES = {
|
||||
"z-ai": "zai",
|
||||
"z.ai": "zai",
|
||||
"zhipu": "zai",
|
||||
"github": "copilot",
|
||||
"github-copilot": "copilot",
|
||||
"github-models": "copilot",
|
||||
"github-model": "copilot",
|
||||
"github-copilot-acp": "copilot-acp",
|
||||
"copilot-acp-agent": "copilot-acp",
|
||||
"kimi": "kimi-coding",
|
||||
"moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn",
|
||||
"minimax_cn": "minimax-cn",
|
||||
"claude": "anthropic",
|
||||
"claude-code": "anthropic",
|
||||
"deep-seek": "deepseek",
|
||||
"opencode": "opencode-zen",
|
||||
"zen": "opencode-zen",
|
||||
"go": "opencode-go",
|
||||
"opencode-go-sub": "opencode-go",
|
||||
"aigateway": "ai-gateway",
|
||||
"vercel": "ai-gateway",
|
||||
"vercel-ai-gateway": "ai-gateway",
|
||||
"kilo": "kilocode",
|
||||
"kilo-code": "kilocode",
|
||||
"kilo-gateway": "kilocode",
|
||||
"dashscope": "alibaba",
|
||||
"aliyun": "alibaba",
|
||||
"qwen": "alibaba",
|
||||
"alibaba-cloud": "alibaba",
|
||||
}
|
||||
|
||||
|
||||
@@ -135,8 +285,10 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
"""
|
||||
# Canonical providers in display order
|
||||
_PROVIDER_ORDER = [
|
||||
"openrouter", "nous", "openai-codex",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn", "anthropic",
|
||||
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn", "kilocode", "anthropic", "alibaba",
|
||||
"opencode-zen", "opencode-go",
|
||||
"ai-gateway", "deepseek", "custom",
|
||||
]
|
||||
# Build reverse alias map
|
||||
aliases_for: dict[str, list[str]] = {}
|
||||
@@ -150,9 +302,15 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
# Check if this provider has credentials available
|
||||
has_creds = False
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
runtime = resolve_runtime_provider(requested=pid)
|
||||
has_creds = bool(runtime.get("api_key"))
|
||||
from hermes_cli.auth import get_auth_status, has_usable_secret
|
||||
if pid == "custom":
|
||||
custom_base_url = _get_custom_base_url() or os.getenv("OPENAI_BASE_URL", "")
|
||||
has_creds = bool(custom_base_url.strip())
|
||||
elif pid == "openrouter":
|
||||
has_creds = has_usable_secret(os.getenv("OPENROUTER_API_KEY", ""))
|
||||
else:
|
||||
status = get_auth_status(pid)
|
||||
has_creds = bool(status.get("logged_in") or status.get("configured"))
|
||||
except Exception:
|
||||
pass
|
||||
result.append({
|
||||
@@ -191,6 +349,19 @@ def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]:
|
||||
return (current_provider, stripped)
|
||||
|
||||
|
||||
def _get_custom_base_url() -> str:
|
||||
"""Get the custom endpoint base_url from config.yaml."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
config = load_config()
|
||||
model_cfg = config.get("model", {})
|
||||
if isinstance(model_cfg, dict):
|
||||
return str(model_cfg.get("base_url", "")).strip()
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]:
|
||||
"""Return ``(model_id, description)`` tuples for a provider's model list.
|
||||
|
||||
@@ -212,6 +383,127 @@ def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]
|
||||
return [(m, "") for m in models]
|
||||
|
||||
|
||||
def detect_provider_for_model(
|
||||
model_name: str,
|
||||
current_provider: str,
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Auto-detect the best provider for a model name.
|
||||
|
||||
Returns ``(provider_id, model_name)`` — the model name may be remapped
|
||||
(e.g. bare ``deepseek-chat`` → ``deepseek/deepseek-chat`` for OpenRouter).
|
||||
Returns ``None`` when no confident match is found.
|
||||
|
||||
Priority:
|
||||
0. Bare provider name → switch to that provider's default model
|
||||
1. Direct provider with credentials (highest)
|
||||
2. Direct provider without credentials → remap to OpenRouter slug
|
||||
3. OpenRouter catalog match
|
||||
"""
|
||||
name = (model_name or "").strip()
|
||||
if not name:
|
||||
return None
|
||||
|
||||
name_lower = name.lower()
|
||||
|
||||
# --- Step 0: bare provider name typed as model ---
|
||||
# If someone types `/model nous` or `/model anthropic`, treat it as a
|
||||
# provider switch and pick the first model from that provider's catalog.
|
||||
# Skip "custom" and "openrouter" — custom has no model catalog, and
|
||||
# openrouter requires an explicit model name to be useful.
|
||||
resolved_provider = _PROVIDER_ALIASES.get(name_lower, name_lower)
|
||||
if resolved_provider not in {"custom", "openrouter"}:
|
||||
default_models = _PROVIDER_MODELS.get(resolved_provider, [])
|
||||
if (
|
||||
resolved_provider in _PROVIDER_LABELS
|
||||
and default_models
|
||||
and resolved_provider != normalize_provider(current_provider)
|
||||
):
|
||||
return (resolved_provider, default_models[0])
|
||||
|
||||
# Aggregators list other providers' models — never auto-switch TO them
|
||||
_AGGREGATORS = {"nous", "openrouter"}
|
||||
|
||||
# If the model belongs to the current provider's catalog, don't suggest switching
|
||||
current_models = _PROVIDER_MODELS.get(current_provider, [])
|
||||
if any(name_lower == m.lower() for m in current_models):
|
||||
return None
|
||||
|
||||
# --- Step 1: check static provider catalogs for a direct match ---
|
||||
direct_match: Optional[str] = None
|
||||
for pid, models in _PROVIDER_MODELS.items():
|
||||
if pid == current_provider or pid in _AGGREGATORS:
|
||||
continue
|
||||
if any(name_lower == m.lower() for m in models):
|
||||
direct_match = pid
|
||||
break
|
||||
|
||||
if direct_match:
|
||||
# Check if we have credentials for this provider
|
||||
has_creds = False
|
||||
try:
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
pconfig = PROVIDER_REGISTRY.get(direct_match)
|
||||
if pconfig:
|
||||
import os
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
if os.getenv(env_var, "").strip():
|
||||
has_creds = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_creds:
|
||||
return (direct_match, name)
|
||||
|
||||
# No direct creds — try to find this model on OpenRouter instead
|
||||
or_slug = _find_openrouter_slug(name)
|
||||
if or_slug:
|
||||
return ("openrouter", or_slug)
|
||||
# Still return the direct provider — credential resolution will
|
||||
# give a clear error rather than silently using the wrong provider
|
||||
return (direct_match, name)
|
||||
|
||||
# --- Step 2: check OpenRouter catalog ---
|
||||
# First try exact match (handles provider/model format)
|
||||
or_slug = _find_openrouter_slug(name)
|
||||
if or_slug:
|
||||
if current_provider != "openrouter":
|
||||
return ("openrouter", or_slug)
|
||||
# Already on openrouter, just return the resolved slug
|
||||
if or_slug != name:
|
||||
return ("openrouter", or_slug)
|
||||
return None # already on openrouter with matching name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_openrouter_slug(model_name: str) -> Optional[str]:
|
||||
"""Find the full OpenRouter model slug for a bare or partial model name.
|
||||
|
||||
Handles:
|
||||
- Exact match: ``anthropic/claude-opus-4.6`` → as-is
|
||||
- Bare name: ``deepseek-chat`` → ``deepseek/deepseek-chat``
|
||||
- Bare name: ``claude-opus-4.6`` → ``anthropic/claude-opus-4.6``
|
||||
"""
|
||||
name_lower = model_name.strip().lower()
|
||||
if not name_lower:
|
||||
return None
|
||||
|
||||
# Exact match (already has provider/ prefix)
|
||||
for mid, _ in OPENROUTER_MODELS:
|
||||
if name_lower == mid.lower():
|
||||
return mid
|
||||
|
||||
# Try matching just the model part (after the /)
|
||||
for mid, _ in OPENROUTER_MODELS:
|
||||
if "/" in mid:
|
||||
_, model_part = mid.split("/", 1)
|
||||
if name_lower == model_part.lower():
|
||||
return mid
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def normalize_provider(provider: Optional[str]) -> str:
|
||||
"""Normalize provider aliases to Hermes' canonical provider ids.
|
||||
|
||||
@@ -233,6 +525,17 @@ def provider_label(provider: Optional[str]) -> str:
|
||||
return _PROVIDER_LABELS.get(normalized, original or "OpenRouter")
|
||||
|
||||
|
||||
def _resolve_copilot_catalog_api_key() -> str:
|
||||
"""Best-effort GitHub token for fetching the Copilot model catalog."""
|
||||
try:
|
||||
from hermes_cli.auth import resolve_api_key_provider_credentials
|
||||
|
||||
creds = resolve_api_key_provider_credentials("copilot")
|
||||
return str(creds.get("api_key") or "").strip()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def provider_model_ids(provider: Optional[str]) -> list[str]:
|
||||
"""Return the best known model catalog for a provider.
|
||||
|
||||
@@ -246,13 +549,22 @@ def provider_model_ids(provider: Optional[str]) -> list[str]:
|
||||
from hermes_cli.codex_models import get_codex_model_ids
|
||||
|
||||
return get_codex_model_ids()
|
||||
if normalized in {"copilot", "copilot-acp"}:
|
||||
try:
|
||||
live = _fetch_github_models(_resolve_copilot_catalog_api_key())
|
||||
if live:
|
||||
return live
|
||||
except Exception:
|
||||
pass
|
||||
if normalized == "copilot-acp":
|
||||
return list(_PROVIDER_MODELS.get("copilot", []))
|
||||
if normalized == "nous":
|
||||
# Try live Nous Portal /models endpoint
|
||||
try:
|
||||
from hermes_cli.auth import fetch_nous_models, resolve_nous_runtime_credentials
|
||||
creds = resolve_nous_runtime_credentials()
|
||||
if creds:
|
||||
live = fetch_nous_models(creds.get("api_key", ""), creds.get("base_url", ""))
|
||||
live = fetch_nous_models(api_key=creds.get("api_key", ""), inference_base_url=creds.get("base_url", ""))
|
||||
if live:
|
||||
return live
|
||||
except Exception:
|
||||
@@ -261,6 +573,22 @@ def provider_model_ids(provider: Optional[str]) -> list[str]:
|
||||
live = _fetch_anthropic_models()
|
||||
if live:
|
||||
return live
|
||||
if normalized == "ai-gateway":
|
||||
live = _fetch_ai_gateway_models()
|
||||
if live:
|
||||
return live
|
||||
if normalized == "custom":
|
||||
base_url = _get_custom_base_url()
|
||||
if base_url:
|
||||
# Try common API key env vars for custom endpoints
|
||||
api_key = (
|
||||
os.getenv("CUSTOM_API_KEY", "")
|
||||
or os.getenv("OPENAI_API_KEY", "")
|
||||
or os.getenv("OPENROUTER_API_KEY", "")
|
||||
)
|
||||
live = fetch_api_models(api_key, base_url)
|
||||
if live:
|
||||
return live
|
||||
return list(_PROVIDER_MODELS.get(normalized, []))
|
||||
|
||||
|
||||
@@ -308,6 +636,401 @@ def _fetch_anthropic_models(timeout: float = 5.0) -> Optional[list[str]]:
|
||||
return None
|
||||
|
||||
|
||||
def _payload_items(payload: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(payload, list):
|
||||
return [item for item in payload if isinstance(item, dict)]
|
||||
if isinstance(payload, dict):
|
||||
data = payload.get("data", [])
|
||||
if isinstance(data, list):
|
||||
return [item for item in data if isinstance(item, dict)]
|
||||
return []
|
||||
|
||||
|
||||
def _extract_model_ids(payload: Any) -> list[str]:
|
||||
return [item.get("id", "") for item in _payload_items(payload) if item.get("id")]
|
||||
|
||||
|
||||
def copilot_default_headers() -> dict[str, str]:
|
||||
"""Standard headers for Copilot API requests.
|
||||
|
||||
Includes Openai-Intent and x-initiator headers that opencode and the
|
||||
Copilot CLI send on every request.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.copilot_auth import copilot_request_headers
|
||||
return copilot_request_headers(is_agent_turn=True)
|
||||
except ImportError:
|
||||
return {
|
||||
"Editor-Version": COPILOT_EDITOR_VERSION,
|
||||
"User-Agent": "HermesAgent/1.0",
|
||||
"Openai-Intent": "conversation-edits",
|
||||
"x-initiator": "agent",
|
||||
}
|
||||
|
||||
|
||||
def _copilot_catalog_item_is_text_model(item: dict[str, Any]) -> bool:
|
||||
model_id = str(item.get("id") or "").strip()
|
||||
if not model_id:
|
||||
return False
|
||||
|
||||
if item.get("model_picker_enabled") is False:
|
||||
return False
|
||||
|
||||
capabilities = item.get("capabilities")
|
||||
if isinstance(capabilities, dict):
|
||||
model_type = str(capabilities.get("type") or "").strip().lower()
|
||||
if model_type and model_type != "chat":
|
||||
return False
|
||||
|
||||
supported_endpoints = item.get("supported_endpoints")
|
||||
if isinstance(supported_endpoints, list):
|
||||
normalized_endpoints = {
|
||||
str(endpoint).strip()
|
||||
for endpoint in supported_endpoints
|
||||
if str(endpoint).strip()
|
||||
}
|
||||
if normalized_endpoints and not normalized_endpoints.intersection(
|
||||
{"/chat/completions", "/responses", "/v1/messages"}
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def fetch_github_model_catalog(
|
||||
api_key: Optional[str] = None, timeout: float = 5.0
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
"""Fetch the live GitHub Copilot model catalog for this account."""
|
||||
attempts: list[dict[str, str]] = []
|
||||
if api_key:
|
||||
attempts.append({
|
||||
**copilot_default_headers(),
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
})
|
||||
attempts.append(copilot_default_headers())
|
||||
|
||||
for headers in attempts:
|
||||
req = urllib.request.Request(COPILOT_MODELS_URL, headers=headers)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
data = json.loads(resp.read().decode())
|
||||
items = _payload_items(data)
|
||||
models: list[dict[str, Any]] = []
|
||||
seen_ids: set[str] = set()
|
||||
for item in items:
|
||||
if not _copilot_catalog_item_is_text_model(item):
|
||||
continue
|
||||
model_id = str(item.get("id") or "").strip()
|
||||
if not model_id or model_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(model_id)
|
||||
models.append(item)
|
||||
if models:
|
||||
return models
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _is_github_models_base_url(base_url: Optional[str]) -> bool:
|
||||
normalized = (base_url or "").strip().rstrip("/").lower()
|
||||
return (
|
||||
normalized.startswith(COPILOT_BASE_URL)
|
||||
or normalized.startswith("https://models.github.ai/inference")
|
||||
)
|
||||
|
||||
|
||||
def _fetch_github_models(api_key: Optional[str] = None, timeout: float = 5.0) -> Optional[list[str]]:
|
||||
catalog = fetch_github_model_catalog(api_key=api_key, timeout=timeout)
|
||||
if not catalog:
|
||||
return None
|
||||
return [item.get("id", "") for item in catalog if item.get("id")]
|
||||
|
||||
|
||||
_COPILOT_MODEL_ALIASES = {
|
||||
"openai/gpt-5": "gpt-5-mini",
|
||||
"openai/gpt-5-chat": "gpt-5-mini",
|
||||
"openai/gpt-5-mini": "gpt-5-mini",
|
||||
"openai/gpt-5-nano": "gpt-5-mini",
|
||||
"openai/gpt-4.1": "gpt-4.1",
|
||||
"openai/gpt-4.1-mini": "gpt-4.1",
|
||||
"openai/gpt-4.1-nano": "gpt-4.1",
|
||||
"openai/gpt-4o": "gpt-4o",
|
||||
"openai/gpt-4o-mini": "gpt-4o-mini",
|
||||
"openai/o1": "gpt-5.2",
|
||||
"openai/o1-mini": "gpt-5-mini",
|
||||
"openai/o1-preview": "gpt-5.2",
|
||||
"openai/o3": "gpt-5.3-codex",
|
||||
"openai/o3-mini": "gpt-5-mini",
|
||||
"openai/o4-mini": "gpt-5-mini",
|
||||
"anthropic/claude-opus-4.6": "claude-opus-4.6",
|
||||
"anthropic/claude-sonnet-4.6": "claude-sonnet-4.6",
|
||||
"anthropic/claude-sonnet-4.5": "claude-sonnet-4.5",
|
||||
"anthropic/claude-haiku-4.5": "claude-haiku-4.5",
|
||||
}
|
||||
|
||||
|
||||
def _copilot_catalog_ids(
|
||||
catalog: Optional[list[dict[str, Any]]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> set[str]:
|
||||
if catalog is None and api_key:
|
||||
catalog = fetch_github_model_catalog(api_key=api_key)
|
||||
if not catalog:
|
||||
return set()
|
||||
return {
|
||||
str(item.get("id") or "").strip()
|
||||
for item in catalog
|
||||
if str(item.get("id") or "").strip()
|
||||
}
|
||||
|
||||
|
||||
def normalize_copilot_model_id(
|
||||
model_id: Optional[str],
|
||||
*,
|
||||
catalog: Optional[list[dict[str, Any]]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> str:
|
||||
raw = str(model_id or "").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
|
||||
catalog_ids = _copilot_catalog_ids(catalog=catalog, api_key=api_key)
|
||||
alias = _COPILOT_MODEL_ALIASES.get(raw)
|
||||
if alias:
|
||||
return alias
|
||||
|
||||
candidates = [raw]
|
||||
if "/" in raw:
|
||||
candidates.append(raw.split("/", 1)[1].strip())
|
||||
|
||||
if raw.endswith("-mini"):
|
||||
candidates.append(raw[:-5])
|
||||
if raw.endswith("-nano"):
|
||||
candidates.append(raw[:-5])
|
||||
if raw.endswith("-chat"):
|
||||
candidates.append(raw[:-5])
|
||||
|
||||
seen: set[str] = set()
|
||||
for candidate in candidates:
|
||||
if not candidate or candidate in seen:
|
||||
continue
|
||||
seen.add(candidate)
|
||||
if candidate in _COPILOT_MODEL_ALIASES:
|
||||
return _COPILOT_MODEL_ALIASES[candidate]
|
||||
if candidate in catalog_ids:
|
||||
return candidate
|
||||
|
||||
if "/" in raw:
|
||||
return raw.split("/", 1)[1].strip()
|
||||
return raw
|
||||
|
||||
|
||||
def _github_reasoning_efforts_for_model_id(model_id: str) -> list[str]:
|
||||
raw = (model_id or "").strip().lower()
|
||||
if raw.startswith(("openai/o1", "openai/o3", "openai/o4", "o1", "o3", "o4")):
|
||||
return list(COPILOT_REASONING_EFFORTS_O_SERIES)
|
||||
normalized = normalize_copilot_model_id(model_id).lower()
|
||||
if normalized.startswith("gpt-5"):
|
||||
return list(COPILOT_REASONING_EFFORTS_GPT5)
|
||||
return []
|
||||
|
||||
|
||||
def _should_use_copilot_responses_api(model_id: str) -> bool:
|
||||
"""Decide whether a Copilot model should use the Responses API.
|
||||
|
||||
Replicates opencode's ``shouldUseCopilotResponsesApi`` logic:
|
||||
GPT-5+ models use Responses API, except ``gpt-5-mini`` which uses
|
||||
Chat Completions. All non-GPT models (Claude, Gemini, etc.) use
|
||||
Chat Completions.
|
||||
"""
|
||||
import re
|
||||
|
||||
match = re.match(r"^gpt-(\d+)", model_id)
|
||||
if not match:
|
||||
return False
|
||||
major = int(match.group(1))
|
||||
return major >= 5 and not model_id.startswith("gpt-5-mini")
|
||||
|
||||
|
||||
def copilot_model_api_mode(
|
||||
model_id: Optional[str],
|
||||
*,
|
||||
catalog: Optional[list[dict[str, Any]]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Determine the API mode for a Copilot model.
|
||||
|
||||
Uses the model ID pattern (matching opencode's approach) as the
|
||||
primary signal. Falls back to the catalog's ``supported_endpoints``
|
||||
only for models not covered by the pattern check.
|
||||
"""
|
||||
normalized = normalize_copilot_model_id(model_id, catalog=catalog, api_key=api_key)
|
||||
if not normalized:
|
||||
return "chat_completions"
|
||||
|
||||
# Primary: model ID pattern (matches opencode's shouldUseCopilotResponsesApi)
|
||||
if _should_use_copilot_responses_api(normalized):
|
||||
return "codex_responses"
|
||||
|
||||
# Secondary: check catalog for non-GPT-5 models (Claude via /v1/messages, etc.)
|
||||
if catalog is None and api_key:
|
||||
catalog = fetch_github_model_catalog(api_key=api_key)
|
||||
|
||||
if catalog:
|
||||
catalog_entry = next((item for item in catalog if item.get("id") == normalized), None)
|
||||
if isinstance(catalog_entry, dict):
|
||||
supported_endpoints = {
|
||||
str(endpoint).strip()
|
||||
for endpoint in (catalog_entry.get("supported_endpoints") or [])
|
||||
if str(endpoint).strip()
|
||||
}
|
||||
# For non-GPT-5 models, check if they only support messages API
|
||||
if "/v1/messages" in supported_endpoints and "/chat/completions" not in supported_endpoints:
|
||||
return "anthropic_messages"
|
||||
|
||||
return "chat_completions"
|
||||
|
||||
|
||||
def github_model_reasoning_efforts(
|
||||
model_id: Optional[str],
|
||||
*,
|
||||
catalog: Optional[list[dict[str, Any]]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
"""Return supported reasoning-effort levels for a Copilot-visible model."""
|
||||
normalized = normalize_copilot_model_id(model_id, catalog=catalog, api_key=api_key)
|
||||
if not normalized:
|
||||
return []
|
||||
|
||||
catalog_entry = None
|
||||
if catalog is not None:
|
||||
catalog_entry = next((item for item in catalog if item.get("id") == normalized), None)
|
||||
elif api_key:
|
||||
fetched_catalog = fetch_github_model_catalog(api_key=api_key)
|
||||
if fetched_catalog:
|
||||
catalog_entry = next((item for item in fetched_catalog if item.get("id") == normalized), None)
|
||||
|
||||
if catalog_entry is not None:
|
||||
capabilities = catalog_entry.get("capabilities")
|
||||
if isinstance(capabilities, dict):
|
||||
supports = capabilities.get("supports")
|
||||
if isinstance(supports, dict):
|
||||
efforts = supports.get("reasoning_effort")
|
||||
if isinstance(efforts, list):
|
||||
normalized_efforts = [
|
||||
str(effort).strip().lower()
|
||||
for effort in efforts
|
||||
if str(effort).strip()
|
||||
]
|
||||
return list(dict.fromkeys(normalized_efforts))
|
||||
return []
|
||||
legacy_capabilities = {
|
||||
str(capability).strip().lower()
|
||||
for capability in catalog_entry.get("capabilities", [])
|
||||
if str(capability).strip()
|
||||
}
|
||||
if "reasoning" not in legacy_capabilities:
|
||||
return []
|
||||
|
||||
return _github_reasoning_efforts_for_model_id(str(model_id or normalized))
|
||||
|
||||
|
||||
def probe_api_models(
|
||||
api_key: Optional[str],
|
||||
base_url: Optional[str],
|
||||
timeout: float = 5.0,
|
||||
) -> dict[str, Any]:
|
||||
"""Probe an OpenAI-compatible ``/models`` endpoint with light URL heuristics."""
|
||||
normalized = (base_url or "").strip().rstrip("/")
|
||||
if not normalized:
|
||||
return {
|
||||
"models": None,
|
||||
"probed_url": None,
|
||||
"resolved_base_url": "",
|
||||
"suggested_base_url": None,
|
||||
"used_fallback": False,
|
||||
}
|
||||
|
||||
if _is_github_models_base_url(normalized):
|
||||
models = _fetch_github_models(api_key=api_key, timeout=timeout)
|
||||
return {
|
||||
"models": models,
|
||||
"probed_url": COPILOT_MODELS_URL,
|
||||
"resolved_base_url": COPILOT_BASE_URL,
|
||||
"suggested_base_url": None,
|
||||
"used_fallback": False,
|
||||
}
|
||||
|
||||
if normalized.endswith("/v1"):
|
||||
alternate_base = normalized[:-3].rstrip("/")
|
||||
else:
|
||||
alternate_base = normalized + "/v1"
|
||||
|
||||
candidates: list[tuple[str, bool]] = [(normalized, False)]
|
||||
if alternate_base and alternate_base != normalized:
|
||||
candidates.append((alternate_base, True))
|
||||
|
||||
tried: list[str] = []
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
if normalized.startswith(COPILOT_BASE_URL):
|
||||
headers.update(copilot_default_headers())
|
||||
|
||||
for candidate_base, is_fallback in candidates:
|
||||
url = candidate_base.rstrip("/") + "/models"
|
||||
tried.append(url)
|
||||
req = urllib.request.Request(url, headers=headers)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
data = json.loads(resp.read().decode())
|
||||
return {
|
||||
"models": [m.get("id", "") for m in data.get("data", [])],
|
||||
"probed_url": url,
|
||||
"resolved_base_url": candidate_base.rstrip("/"),
|
||||
"suggested_base_url": alternate_base if alternate_base != candidate_base else normalized,
|
||||
"used_fallback": is_fallback,
|
||||
}
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return {
|
||||
"models": None,
|
||||
"probed_url": tried[-1] if tried else normalized.rstrip("/") + "/models",
|
||||
"resolved_base_url": normalized,
|
||||
"suggested_base_url": alternate_base if alternate_base != normalized else None,
|
||||
"used_fallback": False,
|
||||
}
|
||||
|
||||
|
||||
def _fetch_ai_gateway_models(timeout: float = 5.0) -> Optional[list[str]]:
|
||||
"""Fetch available language models with tool-use from AI Gateway."""
|
||||
api_key = os.getenv("AI_GATEWAY_API_KEY", "").strip()
|
||||
if not api_key:
|
||||
return None
|
||||
base_url = os.getenv("AI_GATEWAY_BASE_URL", "").strip()
|
||||
if not base_url:
|
||||
from hermes_constants import AI_GATEWAY_BASE_URL
|
||||
base_url = AI_GATEWAY_BASE_URL
|
||||
|
||||
url = base_url.rstrip("/") + "/models"
|
||||
headers: dict[str, str] = {"Authorization": f"Bearer {api_key}"}
|
||||
req = urllib.request.Request(url, headers=headers)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
data = json.loads(resp.read().decode())
|
||||
return [
|
||||
m["id"]
|
||||
for m in data.get("data", [])
|
||||
if m.get("id")
|
||||
and m.get("type") == "language"
|
||||
and "tool-use" in (m.get("tags") or [])
|
||||
]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_api_models(
|
||||
api_key: Optional[str],
|
||||
base_url: Optional[str],
|
||||
@@ -318,22 +1041,7 @@ def fetch_api_models(
|
||||
Returns a list of model ID strings, or ``None`` if the endpoint could not
|
||||
be reached (network error, timeout, auth failure, etc.).
|
||||
"""
|
||||
if not base_url:
|
||||
return None
|
||||
|
||||
url = base_url.rstrip("/") + "/models"
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
req = urllib.request.Request(url, headers=headers)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
data = json.loads(resp.read().decode())
|
||||
# Standard OpenAI format: {"data": [{"id": "model-name", ...}, ...]}
|
||||
return [m.get("id", "") for m in data.get("data", [])]
|
||||
except Exception:
|
||||
return None
|
||||
return probe_api_models(api_key, base_url, timeout=timeout).get("models")
|
||||
|
||||
|
||||
def validate_requested_model(
|
||||
@@ -359,6 +1067,12 @@ def validate_requested_model(
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter" and base_url and "openrouter.ai" not in base_url:
|
||||
normalized = "custom"
|
||||
requested_for_lookup = requested
|
||||
if normalized == "copilot":
|
||||
requested_for_lookup = normalize_copilot_model_id(
|
||||
requested,
|
||||
api_key=api_key,
|
||||
) or requested
|
||||
|
||||
if not requested:
|
||||
return {
|
||||
@@ -376,20 +1090,60 @@ def validate_requested_model(
|
||||
"message": "Model names cannot contain spaces.",
|
||||
}
|
||||
|
||||
# Custom endpoints can serve any model — skip validation
|
||||
if normalized == "custom":
|
||||
probe = probe_api_models(api_key, base_url)
|
||||
api_models = probe.get("models")
|
||||
if api_models is not None:
|
||||
if requested_for_lookup in set(api_models):
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": True,
|
||||
"message": None,
|
||||
}
|
||||
|
||||
suggestions = get_close_matches(requested, api_models, n=3, cutoff=0.5)
|
||||
suggestion_text = ""
|
||||
if suggestions:
|
||||
suggestion_text = "\n Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
|
||||
|
||||
message = (
|
||||
f"Note: `{requested}` was not found in this custom endpoint's model listing "
|
||||
f"({probe.get('probed_url')}). It may still work if the server supports hidden or aliased models."
|
||||
f"{suggestion_text}"
|
||||
)
|
||||
if probe.get("used_fallback"):
|
||||
message += (
|
||||
f"\n Endpoint verification succeeded after trying `{probe.get('resolved_base_url')}`. "
|
||||
f"Consider saving that as your base URL."
|
||||
)
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": False,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
message = (
|
||||
f"Note: could not reach this custom endpoint's model listing at `{probe.get('probed_url')}`. "
|
||||
f"Hermes will still save `{requested}`, but the endpoint should expose `/models` for verification."
|
||||
)
|
||||
if probe.get("suggested_base_url"):
|
||||
message += f"\n If this server expects `/v1`, try base URL: `{probe.get('suggested_base_url')}`"
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": False,
|
||||
"message": None,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
# Probe the live API to check if the model actually exists
|
||||
api_models = fetch_api_models(api_key, base_url)
|
||||
|
||||
if api_models is not None:
|
||||
if requested in set(api_models):
|
||||
if requested_for_lookup in set(api_models):
|
||||
# API confirmed the model exists
|
||||
return {
|
||||
"accepted": True,
|
||||
|
||||
@@ -0,0 +1,501 @@
|
||||
"""
|
||||
Hermes Plugin System
|
||||
====================
|
||||
|
||||
Discovers, loads, and manages plugins from three sources:
|
||||
|
||||
1. **User plugins** – ``~/.hermes/plugins/<name>/``
|
||||
2. **Project plugins** – ``./.hermes/plugins/<name>/`` (opt-in via
|
||||
``HERMES_ENABLE_PROJECT_PLUGINS``)
|
||||
3. **Pip plugins** – packages that expose the ``hermes_agent.plugins``
|
||||
entry-point group.
|
||||
|
||||
Each directory plugin must contain a ``plugin.yaml`` manifest **and** an
|
||||
``__init__.py`` with a ``register(ctx)`` function.
|
||||
|
||||
Lifecycle hooks
|
||||
---------------
|
||||
Plugins may register callbacks for any of the hooks in ``VALID_HOOKS``.
|
||||
The agent core calls ``invoke_hook(name, **kwargs)`` at the appropriate
|
||||
points.
|
||||
|
||||
Tool registration
|
||||
-----------------
|
||||
``PluginContext.register_tool()`` delegates to ``tools.registry.register()``
|
||||
so plugin-defined tools appear alongside the built-in tools.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError: # pragma: no cover – yaml is optional at import time
|
||||
yaml = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
VALID_HOOKS: Set[str] = {
|
||||
"pre_tool_call",
|
||||
"post_tool_call",
|
||||
"pre_llm_call",
|
||||
"post_llm_call",
|
||||
"on_session_start",
|
||||
"on_session_end",
|
||||
}
|
||||
|
||||
ENTRY_POINTS_GROUP = "hermes_agent.plugins"
|
||||
|
||||
_NS_PARENT = "hermes_plugins"
|
||||
|
||||
|
||||
def _env_enabled(name: str) -> bool:
|
||||
"""Return True when an env var is set to a truthy opt-in value."""
|
||||
return os.getenv(name, "").strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class PluginManifest:
|
||||
"""Parsed representation of a plugin.yaml manifest."""
|
||||
|
||||
name: str
|
||||
version: str = ""
|
||||
description: str = ""
|
||||
author: str = ""
|
||||
requires_env: List[str] = field(default_factory=list)
|
||||
provides_tools: List[str] = field(default_factory=list)
|
||||
provides_hooks: List[str] = field(default_factory=list)
|
||||
source: str = "" # "user", "project", or "entrypoint"
|
||||
path: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedPlugin:
|
||||
"""Runtime state for a single loaded plugin."""
|
||||
|
||||
manifest: PluginManifest
|
||||
module: Optional[types.ModuleType] = None
|
||||
tools_registered: List[str] = field(default_factory=list)
|
||||
hooks_registered: List[str] = field(default_factory=list)
|
||||
enabled: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PluginContext – handed to each plugin's ``register()`` function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PluginContext:
|
||||
"""Facade given to plugins so they can register tools and hooks."""
|
||||
|
||||
def __init__(self, manifest: PluginManifest, manager: "PluginManager"):
|
||||
self.manifest = manifest
|
||||
self._manager = manager
|
||||
|
||||
# -- tool registration --------------------------------------------------
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
name: str,
|
||||
toolset: str,
|
||||
schema: dict,
|
||||
handler: Callable,
|
||||
check_fn: Callable | None = None,
|
||||
requires_env: list | None = None,
|
||||
is_async: bool = False,
|
||||
description: str = "",
|
||||
emoji: str = "",
|
||||
) -> None:
|
||||
"""Register a tool in the global registry **and** track it as plugin-provided."""
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name=name,
|
||||
toolset=toolset,
|
||||
schema=schema,
|
||||
handler=handler,
|
||||
check_fn=check_fn,
|
||||
requires_env=requires_env,
|
||||
is_async=is_async,
|
||||
description=description,
|
||||
emoji=emoji,
|
||||
)
|
||||
self._manager._plugin_tool_names.add(name)
|
||||
logger.debug("Plugin %s registered tool: %s", self.manifest.name, name)
|
||||
|
||||
# -- hook registration --------------------------------------------------
|
||||
|
||||
def register_hook(self, hook_name: str, callback: Callable) -> None:
|
||||
"""Register a lifecycle hook callback.
|
||||
|
||||
Unknown hook names produce a warning but are still stored so
|
||||
forward-compatible plugins don't break.
|
||||
"""
|
||||
if hook_name not in VALID_HOOKS:
|
||||
logger.warning(
|
||||
"Plugin '%s' registered unknown hook '%s' "
|
||||
"(valid: %s)",
|
||||
self.manifest.name,
|
||||
hook_name,
|
||||
", ".join(sorted(VALID_HOOKS)),
|
||||
)
|
||||
self._manager._hooks.setdefault(hook_name, []).append(callback)
|
||||
logger.debug("Plugin %s registered hook: %s", self.manifest.name, hook_name)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PluginManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PluginManager:
|
||||
"""Central manager that discovers, loads, and invokes plugins."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._plugins: Dict[str, LoadedPlugin] = {}
|
||||
self._hooks: Dict[str, List[Callable]] = {}
|
||||
self._plugin_tool_names: Set[str] = set()
|
||||
self._discovered: bool = False
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Public
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def discover_and_load(self) -> None:
|
||||
"""Scan all plugin sources and load each plugin found."""
|
||||
if self._discovered:
|
||||
return
|
||||
self._discovered = True
|
||||
|
||||
manifests: List[PluginManifest] = []
|
||||
|
||||
# 1. User plugins (~/.hermes/plugins/)
|
||||
hermes_home = os.environ.get("HERMES_HOME", os.path.expanduser("~/.hermes"))
|
||||
user_dir = Path(hermes_home) / "plugins"
|
||||
manifests.extend(self._scan_directory(user_dir, source="user"))
|
||||
|
||||
# 2. Project plugins (./.hermes/plugins/)
|
||||
if _env_enabled("HERMES_ENABLE_PROJECT_PLUGINS"):
|
||||
project_dir = Path.cwd() / ".hermes" / "plugins"
|
||||
manifests.extend(self._scan_directory(project_dir, source="project"))
|
||||
|
||||
# 3. Pip / entry-point plugins
|
||||
manifests.extend(self._scan_entry_points())
|
||||
|
||||
# Load each manifest
|
||||
for manifest in manifests:
|
||||
self._load_plugin(manifest)
|
||||
|
||||
if manifests:
|
||||
logger.info(
|
||||
"Plugin discovery complete: %d found, %d enabled",
|
||||
len(self._plugins),
|
||||
sum(1 for p in self._plugins.values() if p.enabled),
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Directory scanning
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _scan_directory(self, path: Path, source: str) -> List[PluginManifest]:
|
||||
"""Read ``plugin.yaml`` manifests from subdirectories of *path*."""
|
||||
manifests: List[PluginManifest] = []
|
||||
if not path.is_dir():
|
||||
return manifests
|
||||
|
||||
for child in sorted(path.iterdir()):
|
||||
if not child.is_dir():
|
||||
continue
|
||||
manifest_file = child / "plugin.yaml"
|
||||
if not manifest_file.exists():
|
||||
manifest_file = child / "plugin.yml"
|
||||
if not manifest_file.exists():
|
||||
logger.debug("Skipping %s (no plugin.yaml)", child)
|
||||
continue
|
||||
|
||||
try:
|
||||
if yaml is None:
|
||||
logger.warning("PyYAML not installed – cannot load %s", manifest_file)
|
||||
continue
|
||||
data = yaml.safe_load(manifest_file.read_text()) or {}
|
||||
manifest = PluginManifest(
|
||||
name=data.get("name", child.name),
|
||||
version=str(data.get("version", "")),
|
||||
description=data.get("description", ""),
|
||||
author=data.get("author", ""),
|
||||
requires_env=data.get("requires_env", []),
|
||||
provides_tools=data.get("provides_tools", []),
|
||||
provides_hooks=data.get("provides_hooks", []),
|
||||
source=source,
|
||||
path=str(child),
|
||||
)
|
||||
manifests.append(manifest)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to parse %s: %s", manifest_file, exc)
|
||||
|
||||
return manifests
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Entry-point scanning
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _scan_entry_points(self) -> List[PluginManifest]:
|
||||
"""Check ``importlib.metadata`` for pip-installed plugins."""
|
||||
manifests: List[PluginManifest] = []
|
||||
try:
|
||||
eps = importlib.metadata.entry_points()
|
||||
# Python 3.12+ returns a SelectableGroups; earlier returns dict
|
||||
if hasattr(eps, "select"):
|
||||
group_eps = eps.select(group=ENTRY_POINTS_GROUP)
|
||||
elif isinstance(eps, dict):
|
||||
group_eps = eps.get(ENTRY_POINTS_GROUP, [])
|
||||
else:
|
||||
group_eps = [ep for ep in eps if ep.group == ENTRY_POINTS_GROUP]
|
||||
|
||||
for ep in group_eps:
|
||||
manifest = PluginManifest(
|
||||
name=ep.name,
|
||||
source="entrypoint",
|
||||
path=ep.value,
|
||||
)
|
||||
manifests.append(manifest)
|
||||
except Exception as exc:
|
||||
logger.debug("Entry-point scan failed: %s", exc)
|
||||
|
||||
return manifests
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Loading
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _load_plugin(self, manifest: PluginManifest) -> None:
|
||||
"""Import a plugin module and call its ``register(ctx)`` function."""
|
||||
loaded = LoadedPlugin(manifest=manifest)
|
||||
|
||||
try:
|
||||
if manifest.source in ("user", "project"):
|
||||
module = self._load_directory_module(manifest)
|
||||
else:
|
||||
module = self._load_entrypoint_module(manifest)
|
||||
|
||||
loaded.module = module
|
||||
|
||||
# Call register()
|
||||
register_fn = getattr(module, "register", None)
|
||||
if register_fn is None:
|
||||
loaded.error = "no register() function"
|
||||
logger.warning("Plugin '%s' has no register() function", manifest.name)
|
||||
else:
|
||||
ctx = PluginContext(manifest, self)
|
||||
register_fn(ctx)
|
||||
loaded.tools_registered = [
|
||||
t for t in self._plugin_tool_names
|
||||
if t not in {
|
||||
n
|
||||
for name, p in self._plugins.items()
|
||||
for n in p.tools_registered
|
||||
}
|
||||
]
|
||||
loaded.hooks_registered = list(
|
||||
{
|
||||
h
|
||||
for h, cbs in self._hooks.items()
|
||||
if cbs # non-empty
|
||||
}
|
||||
- {
|
||||
h
|
||||
for name, p in self._plugins.items()
|
||||
for h in p.hooks_registered
|
||||
}
|
||||
)
|
||||
loaded.enabled = True
|
||||
|
||||
except Exception as exc:
|
||||
loaded.error = str(exc)
|
||||
logger.warning("Failed to load plugin '%s': %s", manifest.name, exc)
|
||||
|
||||
self._plugins[manifest.name] = loaded
|
||||
|
||||
def _load_directory_module(self, manifest: PluginManifest) -> types.ModuleType:
|
||||
"""Import a directory-based plugin as ``hermes_plugins.<name>``."""
|
||||
plugin_dir = Path(manifest.path) # type: ignore[arg-type]
|
||||
init_file = plugin_dir / "__init__.py"
|
||||
if not init_file.exists():
|
||||
raise FileNotFoundError(f"No __init__.py in {plugin_dir}")
|
||||
|
||||
# Ensure the namespace parent package exists
|
||||
if _NS_PARENT not in sys.modules:
|
||||
ns_pkg = types.ModuleType(_NS_PARENT)
|
||||
ns_pkg.__path__ = [] # type: ignore[attr-defined]
|
||||
ns_pkg.__package__ = _NS_PARENT
|
||||
sys.modules[_NS_PARENT] = ns_pkg
|
||||
|
||||
module_name = f"{_NS_PARENT}.{manifest.name.replace('-', '_')}"
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
module_name,
|
||||
init_file,
|
||||
submodule_search_locations=[str(plugin_dir)],
|
||||
)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Cannot create module spec for {init_file}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.__package__ = module_name
|
||||
module.__path__ = [str(plugin_dir)] # type: ignore[attr-defined]
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
def _load_entrypoint_module(self, manifest: PluginManifest) -> types.ModuleType:
|
||||
"""Load a pip-installed plugin via its entry-point reference."""
|
||||
eps = importlib.metadata.entry_points()
|
||||
if hasattr(eps, "select"):
|
||||
group_eps = eps.select(group=ENTRY_POINTS_GROUP)
|
||||
elif isinstance(eps, dict):
|
||||
group_eps = eps.get(ENTRY_POINTS_GROUP, [])
|
||||
else:
|
||||
group_eps = [ep for ep in eps if ep.group == ENTRY_POINTS_GROUP]
|
||||
|
||||
for ep in group_eps:
|
||||
if ep.name == manifest.name:
|
||||
return ep.load()
|
||||
|
||||
raise ImportError(
|
||||
f"Entry point '{manifest.name}' not found in group '{ENTRY_POINTS_GROUP}'"
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Hook invocation
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def invoke_hook(self, hook_name: str, **kwargs: Any) -> None:
|
||||
"""Call all registered callbacks for *hook_name*.
|
||||
|
||||
Each callback is wrapped in its own try/except so a misbehaving
|
||||
plugin cannot break the core agent loop.
|
||||
"""
|
||||
callbacks = self._hooks.get(hook_name, [])
|
||||
for cb in callbacks:
|
||||
try:
|
||||
cb(**kwargs)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Hook '%s' callback %s raised: %s",
|
||||
hook_name,
|
||||
getattr(cb, "__name__", repr(cb)),
|
||||
exc,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Introspection
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def list_plugins(self) -> List[Dict[str, Any]]:
|
||||
"""Return a list of info dicts for all discovered plugins."""
|
||||
result: List[Dict[str, Any]] = []
|
||||
for name, loaded in sorted(self._plugins.items()):
|
||||
result.append(
|
||||
{
|
||||
"name": name,
|
||||
"version": loaded.manifest.version,
|
||||
"description": loaded.manifest.description,
|
||||
"source": loaded.manifest.source,
|
||||
"enabled": loaded.enabled,
|
||||
"tools": len(loaded.tools_registered),
|
||||
"hooks": len(loaded.hooks_registered),
|
||||
"error": loaded.error,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton & convenience functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_plugin_manager: Optional[PluginManager] = None
|
||||
|
||||
|
||||
def get_plugin_manager() -> PluginManager:
|
||||
"""Return (and lazily create) the global PluginManager singleton."""
|
||||
global _plugin_manager
|
||||
if _plugin_manager is None:
|
||||
_plugin_manager = PluginManager()
|
||||
return _plugin_manager
|
||||
|
||||
|
||||
def discover_plugins() -> None:
|
||||
"""Discover and load all plugins (idempotent)."""
|
||||
get_plugin_manager().discover_and_load()
|
||||
|
||||
|
||||
def invoke_hook(hook_name: str, **kwargs: Any) -> None:
|
||||
"""Invoke a lifecycle hook on all loaded plugins."""
|
||||
get_plugin_manager().invoke_hook(hook_name, **kwargs)
|
||||
|
||||
|
||||
def get_plugin_tool_names() -> Set[str]:
|
||||
"""Return the set of tool names registered by plugins."""
|
||||
return get_plugin_manager()._plugin_tool_names
|
||||
|
||||
|
||||
def get_plugin_toolsets() -> List[tuple]:
|
||||
"""Return plugin toolsets as ``(key, label, description)`` tuples.
|
||||
|
||||
Used by the ``hermes tools`` TUI so plugin-provided toolsets appear
|
||||
alongside the built-in ones and can be toggled on/off per platform.
|
||||
"""
|
||||
manager = get_plugin_manager()
|
||||
if not manager._plugin_tool_names:
|
||||
return []
|
||||
|
||||
try:
|
||||
from tools.registry import registry
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
# Group plugin tool names by their toolset
|
||||
toolset_tools: Dict[str, List[str]] = {}
|
||||
toolset_plugin: Dict[str, LoadedPlugin] = {}
|
||||
for tool_name in manager._plugin_tool_names:
|
||||
entry = registry._tools.get(tool_name)
|
||||
if not entry:
|
||||
continue
|
||||
ts = entry.toolset
|
||||
toolset_tools.setdefault(ts, []).append(entry.name)
|
||||
|
||||
# Map toolsets back to the plugin that registered them
|
||||
for _name, loaded in manager._plugins.items():
|
||||
for tool_name in loaded.tools_registered:
|
||||
entry = registry._tools.get(tool_name)
|
||||
if entry and entry.toolset in toolset_tools:
|
||||
toolset_plugin.setdefault(entry.toolset, loaded)
|
||||
|
||||
result = []
|
||||
for ts_key in sorted(toolset_tools):
|
||||
plugin = toolset_plugin.get(ts_key)
|
||||
label = f"🔌 {ts_key.replace('_', ' ').title()}"
|
||||
if plugin and plugin.manifest.description:
|
||||
desc = plugin.manifest.description
|
||||
else:
|
||||
desc = ", ".join(sorted(toolset_tools[ts_key]))
|
||||
result.append((ts_key, label, desc))
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,446 @@
|
||||
"""``hermes plugins`` CLI subcommand — install, update, remove, and list plugins.
|
||||
|
||||
Plugins are installed from Git repositories into ``~/.hermes/plugins/``.
|
||||
Supports full URLs and ``owner/repo`` shorthand (resolves to GitHub).
|
||||
|
||||
After install, if the plugin ships an ``after-install.md`` file it is
|
||||
rendered with Rich Markdown. Otherwise a default confirmation is shown.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Minimum manifest version this installer understands.
|
||||
# Plugins may declare ``manifest_version: 1`` in plugin.yaml;
|
||||
# future breaking changes to the manifest schema bump this.
|
||||
_SUPPORTED_MANIFEST_VERSION = 1
|
||||
|
||||
|
||||
def _plugins_dir() -> Path:
|
||||
"""Return the user plugins directory, creating it if needed."""
|
||||
hermes_home = os.environ.get("HERMES_HOME", os.path.expanduser("~/.hermes"))
|
||||
plugins = Path(hermes_home) / "plugins"
|
||||
plugins.mkdir(parents=True, exist_ok=True)
|
||||
return plugins
|
||||
|
||||
|
||||
def _sanitize_plugin_name(name: str, plugins_dir: Path) -> Path:
|
||||
"""Validate a plugin name and return the safe target path inside *plugins_dir*.
|
||||
|
||||
Raises ``ValueError`` if the name contains path-traversal sequences or would
|
||||
resolve outside the plugins directory.
|
||||
"""
|
||||
if not name:
|
||||
raise ValueError("Plugin name must not be empty.")
|
||||
|
||||
# Reject obvious traversal characters
|
||||
for bad in ("/", "\\", ".."):
|
||||
if bad in name:
|
||||
raise ValueError(f"Invalid plugin name '{name}': must not contain '{bad}'.")
|
||||
|
||||
target = (plugins_dir / name).resolve()
|
||||
plugins_resolved = plugins_dir.resolve()
|
||||
|
||||
if (
|
||||
not str(target).startswith(str(plugins_resolved) + os.sep)
|
||||
and target != plugins_resolved
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid plugin name '{name}': resolves outside the plugins directory."
|
||||
)
|
||||
|
||||
return target
|
||||
|
||||
|
||||
def _resolve_git_url(identifier: str) -> str:
|
||||
"""Turn an identifier into a cloneable Git URL.
|
||||
|
||||
Accepted formats:
|
||||
- Full URL: https://github.com/owner/repo.git
|
||||
- Full URL: git@github.com:owner/repo.git
|
||||
- Full URL: ssh://git@github.com/owner/repo.git
|
||||
- Shorthand: owner/repo → https://github.com/owner/repo.git
|
||||
|
||||
NOTE: ``http://`` and ``file://`` schemes are accepted but will trigger a
|
||||
security warning at install time.
|
||||
"""
|
||||
# Already a URL
|
||||
if identifier.startswith(("https://", "http://", "git@", "ssh://", "file://")):
|
||||
return identifier
|
||||
|
||||
# owner/repo shorthand
|
||||
parts = identifier.strip("/").split("/")
|
||||
if len(parts) == 2:
|
||||
owner, repo = parts
|
||||
return f"https://github.com/{owner}/{repo}.git"
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid plugin identifier: '{identifier}'. "
|
||||
"Use a Git URL or owner/repo shorthand."
|
||||
)
|
||||
|
||||
|
||||
def _repo_name_from_url(url: str) -> str:
|
||||
"""Extract the repo name from a Git URL for the plugin directory name."""
|
||||
# Strip trailing .git and slashes
|
||||
name = url.rstrip("/")
|
||||
if name.endswith(".git"):
|
||||
name = name[:-4]
|
||||
# Get last path component
|
||||
name = name.rsplit("/", 1)[-1]
|
||||
# Handle ssh-style urls: git@github.com:owner/repo
|
||||
if ":" in name:
|
||||
name = name.rsplit(":", 1)[-1].rsplit("/", 1)[-1]
|
||||
return name
|
||||
|
||||
|
||||
def _read_manifest(plugin_dir: Path) -> dict:
|
||||
"""Read plugin.yaml and return the parsed dict, or empty dict."""
|
||||
manifest_file = plugin_dir / "plugin.yaml"
|
||||
if not manifest_file.exists():
|
||||
return {}
|
||||
try:
|
||||
import yaml
|
||||
|
||||
with open(manifest_file) as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read plugin.yaml in %s: %s", plugin_dir, e)
|
||||
return {}
|
||||
|
||||
|
||||
def _copy_example_files(plugin_dir: Path, console) -> None:
|
||||
"""Copy any .example files to their real names if they don't already exist.
|
||||
|
||||
For example, ``config.yaml.example`` becomes ``config.yaml``.
|
||||
Skips files that already exist to avoid overwriting user config on reinstall.
|
||||
"""
|
||||
for example_file in plugin_dir.glob("*.example"):
|
||||
real_name = example_file.stem # e.g. "config.yaml" from "config.yaml.example"
|
||||
real_path = plugin_dir / real_name
|
||||
if not real_path.exists():
|
||||
try:
|
||||
shutil.copy2(example_file, real_path)
|
||||
console.print(
|
||||
f"[dim] Created {real_name} from {example_file.name}[/dim]"
|
||||
)
|
||||
except OSError as e:
|
||||
console.print(
|
||||
f"[yellow]Warning:[/yellow] Failed to copy {example_file.name}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _display_after_install(plugin_dir: Path, identifier: str) -> None:
|
||||
"""Show after-install.md if it exists, otherwise a default message."""
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
console = Console()
|
||||
after_install = plugin_dir / "after-install.md"
|
||||
|
||||
if after_install.exists():
|
||||
content = after_install.read_text(encoding="utf-8")
|
||||
md = Markdown(content)
|
||||
console.print()
|
||||
console.print(Panel(md, border_style="green", expand=False))
|
||||
console.print()
|
||||
else:
|
||||
console.print()
|
||||
console.print(
|
||||
Panel(
|
||||
f"[green bold]Plugin installed:[/] {identifier}\n"
|
||||
f"[dim]Location:[/] {plugin_dir}",
|
||||
border_style="green",
|
||||
title="✓ Installed",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
console.print()
|
||||
|
||||
|
||||
def _display_removed(name: str, plugins_dir: Path) -> None:
|
||||
"""Show confirmation after removing a plugin."""
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
console.print()
|
||||
console.print(f"[red]✗[/red] Plugin [bold]{name}[/bold] removed from {plugins_dir}")
|
||||
console.print()
|
||||
|
||||
|
||||
def _require_installed_plugin(name: str, plugins_dir: Path, console) -> Path:
|
||||
"""Return the plugin path if it exists, or exit with an error listing installed plugins."""
|
||||
target = _sanitize_plugin_name(name, plugins_dir)
|
||||
if not target.exists():
|
||||
installed = ", ".join(d.name for d in plugins_dir.iterdir() if d.is_dir()) or "(none)"
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{name}' not found in {plugins_dir}.\n"
|
||||
f"Installed plugins: {installed}"
|
||||
)
|
||||
sys.exit(1)
|
||||
return target
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def cmd_install(identifier: str, force: bool = False) -> None:
|
||||
"""Install a plugin from a Git URL or owner/repo shorthand."""
|
||||
import tempfile
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
|
||||
try:
|
||||
git_url = _resolve_git_url(identifier)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Warn about insecure / local URL schemes
|
||||
if git_url.startswith("http://") or git_url.startswith("file://"):
|
||||
console.print(
|
||||
"[yellow]Warning:[/yellow] Using insecure/local URL scheme. "
|
||||
"Consider using https:// or git@ for production installs."
|
||||
)
|
||||
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
# Clone into a temp directory first so we can read plugin.yaml for the name
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_target = Path(tmp) / "plugin"
|
||||
console.print(f"[dim]Cloning {git_url}...[/dim]")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "clone", "--depth", "1", git_url, str(tmp_target)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
console.print("[red]Error:[/red] git is not installed or not in PATH.")
|
||||
sys.exit(1)
|
||||
except subprocess.TimeoutExpired:
|
||||
console.print("[red]Error:[/red] Git clone timed out after 60 seconds.")
|
||||
sys.exit(1)
|
||||
|
||||
if result.returncode != 0:
|
||||
console.print(
|
||||
f"[red]Error:[/red] Git clone failed:\n{result.stderr.strip()}"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Read manifest
|
||||
manifest = _read_manifest(tmp_target)
|
||||
plugin_name = manifest.get("name") or _repo_name_from_url(git_url)
|
||||
|
||||
# Sanitize plugin name against path traversal
|
||||
try:
|
||||
target = _sanitize_plugin_name(plugin_name, plugins_dir)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Check manifest_version compatibility
|
||||
mv = manifest.get("manifest_version")
|
||||
if mv is not None:
|
||||
try:
|
||||
mv_int = int(mv)
|
||||
except (ValueError, TypeError):
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{plugin_name}' has invalid "
|
||||
f"manifest_version '{mv}' (expected an integer)."
|
||||
)
|
||||
sys.exit(1)
|
||||
if mv_int > _SUPPORTED_MANIFEST_VERSION:
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{plugin_name}' requires manifest_version "
|
||||
f"{mv}, but this installer only supports up to {_SUPPORTED_MANIFEST_VERSION}.\n"
|
||||
f"Run [bold]hermes update[/bold] to get a newer installer."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if target.exists():
|
||||
if not force:
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{plugin_name}' already exists at {target}.\n"
|
||||
f"Use [bold]--force[/bold] to remove and reinstall, or "
|
||||
f"[bold]hermes plugins update {plugin_name}[/bold] to pull latest."
|
||||
)
|
||||
sys.exit(1)
|
||||
console.print(f"[dim] Removing existing {plugin_name}...[/dim]")
|
||||
shutil.rmtree(target)
|
||||
|
||||
# Move from temp to final location
|
||||
shutil.move(str(tmp_target), str(target))
|
||||
|
||||
# Validate it looks like a plugin
|
||||
if not (target / "plugin.yaml").exists() and not (target / "__init__.py").exists():
|
||||
console.print(
|
||||
f"[yellow]Warning:[/yellow] {plugin_name} doesn't contain plugin.yaml "
|
||||
f"or __init__.py. It may not be a valid Hermes plugin."
|
||||
)
|
||||
|
||||
# Copy .example files to their real names (e.g. config.yaml.example → config.yaml)
|
||||
_copy_example_files(target, console)
|
||||
|
||||
_display_after_install(target, identifier)
|
||||
|
||||
console.print("[dim]Restart the gateway for the plugin to take effect:[/dim]")
|
||||
console.print("[dim] hermes gateway restart[/dim]")
|
||||
console.print()
|
||||
|
||||
|
||||
def cmd_update(name: str) -> None:
|
||||
"""Update an installed plugin by pulling latest from its git remote."""
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
try:
|
||||
target = _require_installed_plugin(name, plugins_dir, console)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if not (target / ".git").exists():
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{name}' was not installed from git "
|
||||
f"(no .git directory). Cannot update."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
console.print(f"[dim]Updating {name}...[/dim]")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "pull", "--ff-only"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
cwd=str(target),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
console.print("[red]Error:[/red] git is not installed or not in PATH.")
|
||||
sys.exit(1)
|
||||
except subprocess.TimeoutExpired:
|
||||
console.print("[red]Error:[/red] Git pull timed out after 60 seconds.")
|
||||
sys.exit(1)
|
||||
|
||||
if result.returncode != 0:
|
||||
console.print(f"[red]Error:[/red] Git pull failed:\n{result.stderr.strip()}")
|
||||
sys.exit(1)
|
||||
|
||||
# Copy any new .example files
|
||||
_copy_example_files(target, console)
|
||||
|
||||
output = result.stdout.strip()
|
||||
if "Already up to date" in output:
|
||||
console.print(
|
||||
f"[green]✓[/green] Plugin [bold]{name}[/bold] is already up to date."
|
||||
)
|
||||
else:
|
||||
console.print(f"[green]✓[/green] Plugin [bold]{name}[/bold] updated.")
|
||||
console.print(f"[dim]{output}[/dim]")
|
||||
|
||||
|
||||
def cmd_remove(name: str) -> None:
|
||||
"""Remove an installed plugin by name."""
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
try:
|
||||
target = _require_installed_plugin(name, plugins_dir, console)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
shutil.rmtree(target)
|
||||
_display_removed(name, plugins_dir)
|
||||
|
||||
|
||||
def cmd_list() -> None:
|
||||
"""List installed plugins."""
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
yaml = None
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
dirs = sorted(d for d in plugins_dir.iterdir() if d.is_dir())
|
||||
if not dirs:
|
||||
console.print("[dim]No plugins installed.[/dim]")
|
||||
console.print(f"[dim]Install with:[/dim] hermes plugins install owner/repo")
|
||||
return
|
||||
|
||||
table = Table(title="Installed Plugins", show_lines=False)
|
||||
table.add_column("Name", style="bold")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("Description")
|
||||
table.add_column("Source", style="dim")
|
||||
|
||||
for d in dirs:
|
||||
manifest_file = d / "plugin.yaml"
|
||||
name = d.name
|
||||
version = ""
|
||||
description = ""
|
||||
source = "local"
|
||||
|
||||
if manifest_file.exists() and yaml:
|
||||
try:
|
||||
with open(manifest_file) as f:
|
||||
manifest = yaml.safe_load(f) or {}
|
||||
name = manifest.get("name", d.name)
|
||||
version = manifest.get("version", "")
|
||||
description = manifest.get("description", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check if it's a git repo (installed via hermes plugins install)
|
||||
if (d / ".git").exists():
|
||||
source = "git"
|
||||
|
||||
table.add_row(name, str(version), description, source)
|
||||
|
||||
console.print()
|
||||
console.print(table)
|
||||
console.print()
|
||||
|
||||
|
||||
def plugins_command(args) -> None:
|
||||
"""Dispatch hermes plugins subcommands."""
|
||||
action = getattr(args, "plugins_action", None)
|
||||
|
||||
if action == "install":
|
||||
cmd_install(args.identifier, force=getattr(args, "force", False))
|
||||
elif action == "update":
|
||||
cmd_update(args.name)
|
||||
elif action in ("remove", "rm", "uninstall"):
|
||||
cmd_remove(args.name)
|
||||
elif action in ("list", "ls") or action is None:
|
||||
cmd_list()
|
||||
else:
|
||||
from rich.console import Console
|
||||
|
||||
Console().print(f"[red]Unknown plugins action: {action}[/red]")
|
||||
sys.exit(1)
|
||||
+166
-34
@@ -14,6 +14,8 @@ from hermes_cli.auth import (
|
||||
resolve_nous_runtime_credentials,
|
||||
resolve_codex_runtime_credentials,
|
||||
resolve_api_key_provider_credentials,
|
||||
resolve_external_process_provider_credentials,
|
||||
has_usable_secret,
|
||||
)
|
||||
from hermes_cli.config import load_config
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
@@ -23,16 +25,87 @@ def _normalize_custom_provider_name(value: str) -> str:
|
||||
return value.strip().lower().replace(" ", "-")
|
||||
|
||||
|
||||
def _detect_api_mode_for_url(base_url: str) -> Optional[str]:
|
||||
"""Auto-detect api_mode from the resolved base URL.
|
||||
|
||||
Direct api.openai.com endpoints need the Responses API for GPT-5.x
|
||||
tool calls with reasoning (chat/completions returns 400).
|
||||
"""
|
||||
normalized = (base_url or "").strip().lower().rstrip("/")
|
||||
if "api.openai.com" in normalized and "openrouter" not in normalized:
|
||||
return "codex_responses"
|
||||
return None
|
||||
|
||||
|
||||
def _auto_detect_local_model(base_url: str) -> str:
|
||||
"""Query a local server for its model name when only one model is loaded."""
|
||||
if not base_url:
|
||||
return ""
|
||||
try:
|
||||
import requests
|
||||
url = base_url.rstrip("/")
|
||||
if not url.endswith("/v1"):
|
||||
url += "/v1"
|
||||
resp = requests.get(url + "/models", timeout=5)
|
||||
if resp.ok:
|
||||
models = resp.json().get("data", [])
|
||||
if len(models) == 1:
|
||||
model_id = models[0].get("id", "")
|
||||
if model_id:
|
||||
return model_id
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def _get_model_config() -> Dict[str, Any]:
|
||||
config = load_config()
|
||||
model_cfg = config.get("model")
|
||||
if isinstance(model_cfg, dict):
|
||||
return dict(model_cfg)
|
||||
cfg = dict(model_cfg)
|
||||
default = cfg.get("default", "").strip()
|
||||
base_url = cfg.get("base_url", "").strip()
|
||||
is_local = "localhost" in base_url or "127.0.0.1" in base_url
|
||||
is_fallback = not default or default == "anthropic/claude-opus-4.6"
|
||||
if is_local and is_fallback and base_url:
|
||||
detected = _auto_detect_local_model(base_url)
|
||||
if detected:
|
||||
cfg["default"] = detected
|
||||
return cfg
|
||||
if isinstance(model_cfg, str) and model_cfg.strip():
|
||||
return {"default": model_cfg.strip()}
|
||||
return {}
|
||||
|
||||
|
||||
def _copilot_runtime_api_mode(model_cfg: Dict[str, Any], api_key: str) -> str:
|
||||
configured_mode = _parse_api_mode(model_cfg.get("api_mode"))
|
||||
if configured_mode:
|
||||
return configured_mode
|
||||
|
||||
model_name = str(model_cfg.get("default") or "").strip()
|
||||
if not model_name:
|
||||
return "chat_completions"
|
||||
|
||||
try:
|
||||
from hermes_cli.models import copilot_model_api_mode
|
||||
|
||||
return copilot_model_api_mode(model_name, api_key=api_key)
|
||||
except Exception:
|
||||
return "chat_completions"
|
||||
|
||||
|
||||
_VALID_API_MODES = {"chat_completions", "codex_responses", "anthropic_messages"}
|
||||
|
||||
|
||||
def _parse_api_mode(raw: Any) -> Optional[str]:
|
||||
"""Validate an api_mode value from config. Returns None if invalid."""
|
||||
if isinstance(raw, str):
|
||||
normalized = raw.strip().lower()
|
||||
if normalized in _VALID_API_MODES:
|
||||
return normalized
|
||||
return None
|
||||
|
||||
|
||||
def resolve_requested_provider(requested: Optional[str] = None) -> str:
|
||||
"""Resolve provider request from explicit arg, config, then env."""
|
||||
if requested and requested.strip():
|
||||
@@ -86,11 +159,15 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
|
||||
menu_key = f"custom:{name_norm}"
|
||||
if requested_norm not in {name_norm, menu_key}:
|
||||
continue
|
||||
return {
|
||||
result = {
|
||||
"name": name.strip(),
|
||||
"base_url": base_url.strip(),
|
||||
"api_key": str(entry.get("api_key", "") or "").strip(),
|
||||
}
|
||||
api_mode = _parse_api_mode(entry.get("api_mode"))
|
||||
if api_mode:
|
||||
result["api_mode"] = api_mode
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
@@ -112,16 +189,19 @@ def _resolve_named_custom_runtime(
|
||||
if not base_url:
|
||||
return None
|
||||
|
||||
api_key = (
|
||||
(explicit_api_key or "").strip()
|
||||
or custom_provider.get("api_key", "")
|
||||
or os.getenv("OPENAI_API_KEY", "").strip()
|
||||
or os.getenv("OPENROUTER_API_KEY", "").strip()
|
||||
)
|
||||
api_key_candidates = [
|
||||
(explicit_api_key or "").strip(),
|
||||
str(custom_provider.get("api_key", "") or "").strip(),
|
||||
os.getenv("OPENAI_API_KEY", "").strip(),
|
||||
os.getenv("OPENROUTER_API_KEY", "").strip(),
|
||||
]
|
||||
api_key = next((candidate for candidate in api_key_candidates if has_usable_secret(candidate)), "")
|
||||
|
||||
return {
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"api_mode": custom_provider.get("api_mode")
|
||||
or _detect_api_mode_for_url(base_url)
|
||||
or "chat_completions",
|
||||
"base_url": base_url,
|
||||
"api_key": api_key,
|
||||
"source": f"custom_provider:{custom_provider.get('name', requested_provider)}",
|
||||
@@ -137,6 +217,12 @@ def _resolve_openrouter_runtime(
|
||||
model_cfg = _get_model_config()
|
||||
cfg_base_url = model_cfg.get("base_url") if isinstance(model_cfg.get("base_url"), str) else ""
|
||||
cfg_provider = model_cfg.get("provider") if isinstance(model_cfg.get("provider"), str) else ""
|
||||
cfg_api_key = ""
|
||||
for k in ("api_key", "api"):
|
||||
v = model_cfg.get(k)
|
||||
if isinstance(v, str) and v.strip():
|
||||
cfg_api_key = v.strip()
|
||||
break
|
||||
requested_norm = (requested_provider or "").strip().lower()
|
||||
cfg_provider = cfg_provider.strip().lower()
|
||||
|
||||
@@ -144,26 +230,24 @@ def _resolve_openrouter_runtime(
|
||||
env_openrouter_base_url = os.getenv("OPENROUTER_BASE_URL", "").strip()
|
||||
|
||||
use_config_base_url = False
|
||||
if cfg_base_url.strip() and not explicit_base_url and not env_openai_base_url:
|
||||
if cfg_base_url.strip() and not explicit_base_url:
|
||||
if requested_norm == "auto":
|
||||
if not cfg_provider or cfg_provider == "auto":
|
||||
use_config_base_url = True
|
||||
elif requested_norm == "custom":
|
||||
# Persisted custom endpoints store their base URL in config.yaml.
|
||||
# If OPENAI_BASE_URL is not currently set in the environment, keep
|
||||
# honoring that saved endpoint instead of falling back to OpenRouter.
|
||||
if cfg_provider == "custom":
|
||||
if (not cfg_provider or cfg_provider == "auto") and not env_openai_base_url:
|
||||
use_config_base_url = True
|
||||
elif requested_norm == "custom" and cfg_provider == "custom":
|
||||
# provider: custom — use base_url from config (Fixes #1760).
|
||||
use_config_base_url = True
|
||||
|
||||
# When the user explicitly requested the openrouter provider, skip
|
||||
# OPENAI_BASE_URL — it typically points to a custom / non-OpenRouter
|
||||
# endpoint and would prevent switching back to OpenRouter (#874).
|
||||
skip_openai_base = requested_norm == "openrouter"
|
||||
|
||||
# For custom, prefer config base_url over env so config.yaml is honored (#1760).
|
||||
base_url = (
|
||||
(explicit_base_url or "").strip()
|
||||
or ("" if skip_openai_base else env_openai_base_url)
|
||||
or (cfg_base_url.strip() if use_config_base_url else "")
|
||||
or ("" if skip_openai_base else env_openai_base_url)
|
||||
or env_openrouter_base_url
|
||||
or OPENROUTER_BASE_URL
|
||||
).rstrip("/")
|
||||
@@ -175,25 +259,31 @@ def _resolve_openrouter_runtime(
|
||||
# provider (issues #420, #560).
|
||||
_is_openrouter_url = "openrouter.ai" in base_url
|
||||
if _is_openrouter_url:
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or ""
|
||||
)
|
||||
api_key_candidates = [
|
||||
explicit_api_key,
|
||||
os.getenv("OPENROUTER_API_KEY"),
|
||||
os.getenv("OPENAI_API_KEY"),
|
||||
]
|
||||
else:
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or ""
|
||||
)
|
||||
# Custom endpoint: use api_key from config when using config base_url (#1760).
|
||||
api_key_candidates = [
|
||||
explicit_api_key,
|
||||
(cfg_api_key if use_config_base_url else ""),
|
||||
os.getenv("OPENAI_API_KEY"),
|
||||
os.getenv("OPENROUTER_API_KEY"),
|
||||
]
|
||||
api_key = next(
|
||||
(str(candidate or "").strip() for candidate in api_key_candidates if has_usable_secret(candidate)),
|
||||
"",
|
||||
)
|
||||
|
||||
source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config"
|
||||
|
||||
return {
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"api_mode": _parse_api_mode(model_cfg.get("api_mode"))
|
||||
or _detect_api_mode_for_url(base_url)
|
||||
or "chat_completions",
|
||||
"base_url": base_url,
|
||||
"api_key": api_key,
|
||||
"source": source,
|
||||
@@ -251,6 +341,19 @@ def resolve_runtime_provider(
|
||||
"requested_provider": requested_provider,
|
||||
}
|
||||
|
||||
if provider == "copilot-acp":
|
||||
creds = resolve_external_process_provider_credentials(provider)
|
||||
return {
|
||||
"provider": "copilot-acp",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": creds.get("base_url", "").rstrip("/"),
|
||||
"api_key": creds.get("api_key", ""),
|
||||
"command": creds.get("command", ""),
|
||||
"args": list(creds.get("args") or []),
|
||||
"source": creds.get("source", "process"),
|
||||
"requested_provider": requested_provider,
|
||||
}
|
||||
|
||||
# Anthropic (native Messages API)
|
||||
if provider == "anthropic":
|
||||
from agent.anthropic_adapter import resolve_anthropic_token
|
||||
@@ -260,10 +363,19 @@ def resolve_runtime_provider(
|
||||
"No Anthropic credentials found. Set ANTHROPIC_TOKEN or ANTHROPIC_API_KEY, "
|
||||
"run 'claude setup-token', or authenticate with 'claude /login'."
|
||||
)
|
||||
# Allow base URL override from config.yaml model.base_url, but only
|
||||
# when the configured provider is anthropic — otherwise a non-Anthropic
|
||||
# base_url (e.g. Codex endpoint) would leak into Anthropic requests.
|
||||
model_cfg = _get_model_config()
|
||||
cfg_provider = str(model_cfg.get("provider") or "").strip().lower()
|
||||
cfg_base_url = ""
|
||||
if cfg_provider == "anthropic":
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
base_url = cfg_base_url or "https://api.anthropic.com"
|
||||
return {
|
||||
"provider": "anthropic",
|
||||
"api_mode": "anthropic_messages",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"base_url": base_url,
|
||||
"api_key": token,
|
||||
"source": "env",
|
||||
"requested_provider": requested_provider,
|
||||
@@ -273,10 +385,30 @@ def resolve_runtime_provider(
|
||||
pconfig = PROVIDER_REGISTRY.get(provider)
|
||||
if pconfig and pconfig.auth_type == "api_key":
|
||||
creds = resolve_api_key_provider_credentials(provider)
|
||||
model_cfg = _get_model_config()
|
||||
base_url = creds.get("base_url", "").rstrip("/")
|
||||
api_mode = "chat_completions"
|
||||
if provider == "copilot":
|
||||
api_mode = _copilot_runtime_api_mode(model_cfg, creds.get("api_key", ""))
|
||||
else:
|
||||
# Check explicit api_mode from model config first
|
||||
configured_mode = _parse_api_mode(model_cfg.get("api_mode"))
|
||||
if configured_mode:
|
||||
api_mode = configured_mode
|
||||
# Auto-detect Anthropic-compatible endpoints by URL convention
|
||||
# (e.g. https://api.minimax.io/anthropic, https://dashscope.../anthropic)
|
||||
elif base_url.rstrip("/").endswith("/anthropic"):
|
||||
api_mode = "anthropic_messages"
|
||||
# MiniMax providers always use Anthropic Messages API.
|
||||
# Auto-correct stale /v1 URLs (from old .env or config) to /anthropic.
|
||||
elif provider in ("minimax", "minimax-cn"):
|
||||
api_mode = "anthropic_messages"
|
||||
if base_url.rstrip("/").endswith("/v1"):
|
||||
base_url = base_url.rstrip("/")[:-3] + "/anthropic"
|
||||
return {
|
||||
"provider": provider,
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": creds.get("base_url", "").rstrip("/"),
|
||||
"api_mode": api_mode,
|
||||
"base_url": base_url,
|
||||
"api_key": creds.get("api_key", ""),
|
||||
"source": creds.get("source", "env"),
|
||||
"requested_provider": requested_provider,
|
||||
|
||||
+852
-215
File diff suppressed because it is too large
Load Diff
+30
-17
@@ -304,7 +304,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
|
||||
|
||||
def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
console: Optional[Console] = None) -> None:
|
||||
console: Optional[Console] = None, skip_confirm: bool = False) -> None:
|
||||
"""Fetch, quarantine, scan, confirm, and install a skill."""
|
||||
from tools.skills_hub import (
|
||||
GitHubAuth, create_source_router, ensure_hub_dirs,
|
||||
@@ -378,7 +378,8 @@ def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
c.print(Panel("\n".join(metadata_lines), title="Upstream Metadata", border_style="blue"))
|
||||
|
||||
# Confirm with user — show appropriate warning based on source
|
||||
if not force:
|
||||
# skip_confirm bypasses the prompt (needed in TUI mode where input() hangs)
|
||||
if not force and not skip_confirm:
|
||||
c.print()
|
||||
if bundle.source == "official":
|
||||
c.print(Panel(
|
||||
@@ -454,6 +455,8 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
||||
|
||||
if bundle and "SKILL.md" in bundle.files:
|
||||
content = bundle.files["SKILL.md"]
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
# Show first 50 lines as preview
|
||||
lines = content.split("\n")
|
||||
preview = "\n".join(lines[:50])
|
||||
@@ -598,20 +601,23 @@ def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> N
|
||||
c.print()
|
||||
|
||||
|
||||
def do_uninstall(name: str, console: Optional[Console] = None) -> None:
|
||||
def do_uninstall(name: str, console: Optional[Console] = None,
|
||||
skip_confirm: bool = False) -> None:
|
||||
"""Remove a hub-installed skill with confirmation."""
|
||||
from tools.skills_hub import uninstall_skill
|
||||
|
||||
c = console or _console
|
||||
|
||||
c.print(f"\n[bold]Uninstall '{name}'?[/]")
|
||||
try:
|
||||
answer = input("Confirm [y/N]: ").strip().lower()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
answer = "n"
|
||||
if answer not in ("y", "yes"):
|
||||
c.print("[dim]Cancelled.[/]\n")
|
||||
return
|
||||
# skip_confirm bypasses the prompt (needed in TUI mode where input() hangs)
|
||||
if not skip_confirm:
|
||||
c.print(f"\n[bold]Uninstall '{name}'?[/]")
|
||||
try:
|
||||
answer = input("Confirm [y/N]: ").strip().lower()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
answer = "n"
|
||||
if answer not in ("y", "yes"):
|
||||
c.print("[dim]Cancelled.[/]\n")
|
||||
return
|
||||
|
||||
success, msg = uninstall_skill(name)
|
||||
if success:
|
||||
@@ -636,7 +642,8 @@ def do_tap(action: str, repo: str = "", console: Optional[Console] = None) -> No
|
||||
table.add_column("Repo", style="bold cyan")
|
||||
table.add_column("Path", style="dim")
|
||||
for t in taps:
|
||||
table.add_row(t["repo"], t.get("path", "skills/"))
|
||||
label = t.get("repo") or t.get("name") or t.get("path", "unknown")
|
||||
table.add_row(label, t.get("path", "skills/"))
|
||||
c.print(table)
|
||||
c.print()
|
||||
|
||||
@@ -923,7 +930,8 @@ def skills_command(args) -> None:
|
||||
elif action == "search":
|
||||
do_search(args.query, source=args.source, limit=args.limit)
|
||||
elif action == "install":
|
||||
do_install(args.identifier, category=args.category, force=args.force)
|
||||
do_install(args.identifier, category=args.category, force=args.force,
|
||||
skip_confirm=getattr(args, "yes", False))
|
||||
elif action == "inspect":
|
||||
do_inspect(args.identifier)
|
||||
elif action == "list":
|
||||
@@ -1054,11 +1062,15 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None:
|
||||
return
|
||||
identifier = args[0]
|
||||
category = ""
|
||||
force = any(flag in args for flag in ("--force", "--yes", "-y"))
|
||||
# --yes / -y bypasses confirmation prompt (needed in TUI mode)
|
||||
# --force handles reinstall override
|
||||
skip_confirm = any(flag in args for flag in ("--yes", "-y"))
|
||||
force = "--force" in args
|
||||
for i, a in enumerate(args):
|
||||
if a == "--category" and i + 1 < len(args):
|
||||
category = args[i + 1]
|
||||
do_install(identifier, category=category, force=force, console=c)
|
||||
do_install(identifier, category=category, force=force,
|
||||
skip_confirm=skip_confirm, console=c)
|
||||
|
||||
elif action == "inspect":
|
||||
if not args:
|
||||
@@ -1088,9 +1100,10 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None:
|
||||
|
||||
elif action == "uninstall":
|
||||
if not args:
|
||||
c.print("[bold red]Usage:[/] /skills uninstall <name>\n")
|
||||
c.print("[bold red]Usage:[/] /skills uninstall <name> [--yes]\n")
|
||||
return
|
||||
do_uninstall(args[0], console=c)
|
||||
skip_confirm = any(flag in args for flag in ("--yes", "-y"))
|
||||
do_uninstall(args[0], console=c, skip_confirm=skip_confirm)
|
||||
|
||||
elif action == "publish":
|
||||
if not args:
|
||||
|
||||
@@ -60,6 +60,12 @@ All fields are optional. Missing values inherit from the ``default`` skin.
|
||||
# Tool prefix: character for tool output lines (default: ┊)
|
||||
tool_prefix: "┊"
|
||||
|
||||
# Tool emojis: override the default emoji for any tool (used in spinners & progress)
|
||||
tool_emojis:
|
||||
terminal: "⚔" # Override terminal tool emoji
|
||||
web_search: "🔮" # Override web_search tool emoji
|
||||
# Any tool not listed here uses its registry default
|
||||
|
||||
USAGE
|
||||
=====
|
||||
|
||||
@@ -111,6 +117,7 @@ class SkinConfig:
|
||||
spinner: Dict[str, Any] = field(default_factory=dict)
|
||||
branding: Dict[str, str] = field(default_factory=dict)
|
||||
tool_prefix: str = "┊"
|
||||
tool_emojis: Dict[str, str] = field(default_factory=dict) # per-tool emoji overrides
|
||||
banner_logo: str = "" # Rich-markup ASCII art logo (replaces HERMES_AGENT_LOGO)
|
||||
banner_hero: str = "" # Rich-markup hero art (replaces HERMES_CADUCEUS)
|
||||
|
||||
@@ -344,12 +351,12 @@ _BUILTIN_SKINS: Dict[str, Dict[str, Any]] = {
|
||||
"help_header": "(Ψ) Available Commands",
|
||||
},
|
||||
"tool_prefix": "│",
|
||||
"banner_logo": """[bold #B8E8FF]██████╗ ██████╗ ███████╗██╗██████╗ ███████╗ ██████╗ ███╗ ██╗ █████╗ ██████╗ ███████╗███╗ ██╗████████╗[/]
|
||||
[bold #97D6FF]██╔══██╗██╔═══██╗██╔════╝██║██╔══██╗██╔════╝██╔═══██╗████╗ ██║ ██╔══██╗██╔════╝ ██╔════╝████╗ ██║╚══██╔══╝[/]
|
||||
[#75C1F6]██████╔╝██║ ██║███████╗██║██║ ██║█████╗ ██║ ██║██╔██╗ ██║█████╗███████║██║ ███╗█████╗ ██╔██╗ ██║ ██║[/]
|
||||
[#4FA2E0]██╔═══╝ ██║ ██║╚════██║██║██║ ██║██╔══╝ ██║ ██║██║╚██╗██║╚════╝██╔══██║██║ ██║██╔══╝ ██║╚██╗██║ ██║[/]
|
||||
[#2E7CC7]██║ ╚██████╔╝███████║██║██████╔╝███████╗╚██████╔╝██║ ╚████║ ██║ ██║╚██████╔╝███████╗██║ ╚████║ ██║[/]
|
||||
[#1B4F95]╚═╝ ╚═════╝ ╚══════╝╚═╝╚═════╝ ╚══════╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═╝[/]""",
|
||||
"banner_logo": """[bold #B8E8FF]██████╗ ██████╗ ███████╗███████╗██╗██████╗ ██████╗ ███╗ ██╗ █████╗ ██████╗ ███████╗███╗ ██╗████████╗[/]
|
||||
[bold #97D6FF]██╔══██╗██╔═══██╗██╔════╝██╔════╝██║██╔══██╗██╔═══██╗████╗ ██║ ██╔══██╗██╔════╝ ██╔════╝████╗ ██║╚══██╔══╝[/]
|
||||
[#75C1F6]██████╔╝██║ ██║███████╗█████╗ ██║██║ ██║██║ ██║██╔██╗ ██║█████╗███████║██║ ███╗█████╗ ██╔██╗ ██║ ██║[/]
|
||||
[#4FA2E0]██╔═══╝ ██║ ██║╚════██║██╔══╝ ██║██║ ██║██║ ██║██║╚██╗██║╚════╝██╔══██║██║ ██║██╔══╝ ██║╚██╗██║ ██║[/]
|
||||
[#2E7CC7]██║ ╚██████╔╝███████║███████╗██║██████╔╝╚██████╔╝██║ ╚████║ ██║ ██║╚██████╔╝███████╗██║ ╚████║ ██║[/]
|
||||
[#1B4F95]╚═╝ ╚═════╝ ╚══════╝╚══════╝╚═╝╚═════╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═╝[/]""",
|
||||
"banner_hero": """[#2A6FB9]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]
|
||||
[#5DB8F5]⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⣾⣿⣷⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]
|
||||
[#5DB8F5]⠀⠀⠀⠀⠀⠀⠀⢠⣿⠏⠀Ψ⠀⠹⣿⡄⠀⠀⠀⠀⠀⠀⠀[/]
|
||||
@@ -541,6 +548,7 @@ def _build_skin_config(data: Dict[str, Any]) -> SkinConfig:
|
||||
spinner=spinner,
|
||||
branding=branding,
|
||||
tool_prefix=data.get("tool_prefix", default.get("tool_prefix", "┊")),
|
||||
tool_emojis=data.get("tool_emojis", {}),
|
||||
banner_logo=data.get("banner_logo", ""),
|
||||
banner_hero=data.get("banner_hero", ""),
|
||||
)
|
||||
|
||||
@@ -120,6 +120,7 @@ def show_status(args):
|
||||
"MiniMax": "MINIMAX_API_KEY",
|
||||
"MiniMax-CN": "MINIMAX_CN_API_KEY",
|
||||
"Firecrawl": "FIRECRAWL_API_KEY",
|
||||
"Tavily": "TAVILY_API_KEY",
|
||||
"Browserbase": "BROWSERBASE_API_KEY", # Optional — local browser works without this
|
||||
"FAL": "FAL_KEY",
|
||||
"Tinker": "TINKER_API_KEY",
|
||||
@@ -252,6 +253,7 @@ def show_status(args):
|
||||
"Signal": ("SIGNAL_HTTP_URL", "SIGNAL_HOME_CHANNEL"),
|
||||
"Slack": ("SLACK_BOT_TOKEN", None),
|
||||
"Email": ("EMAIL_ADDRESS", "EMAIL_HOME_ADDRESS"),
|
||||
"SMS": ("TWILIO_ACCOUNT_SID", "SMS_HOME_CHANNEL"),
|
||||
}
|
||||
|
||||
for name, (token_var, home_var) in platforms.items():
|
||||
@@ -275,8 +277,13 @@ def show_status(args):
|
||||
print(color("◆ Gateway Service", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
if sys.platform.startswith('linux'):
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
_gw_svc = get_service_name()
|
||||
except Exception:
|
||||
_gw_svc = "hermes-gateway"
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", "hermes-gateway"],
|
||||
["systemctl", "--user", "is-active", _gw_svc],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
+462
-56
@@ -101,6 +101,30 @@ CONFIGURABLE_TOOLSETS = [
|
||||
# but the setup checklist won't pre-select them for first-time users.
|
||||
_DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "rl"}
|
||||
|
||||
|
||||
def _get_effective_configurable_toolsets():
|
||||
"""Return CONFIGURABLE_TOOLSETS + any plugin-provided toolsets.
|
||||
|
||||
Plugin toolsets are appended at the end so they appear after the
|
||||
built-in toolsets in the TUI checklist.
|
||||
"""
|
||||
result = list(CONFIGURABLE_TOOLSETS)
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_toolsets
|
||||
result.extend(get_plugin_toolsets())
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def _get_plugin_toolset_keys() -> set:
|
||||
"""Return the set of toolset keys provided by plugins."""
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_toolsets
|
||||
return {ts_key for ts_key, _, _ in get_plugin_toolsets()}
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
# Platform display config
|
||||
PLATFORMS = {
|
||||
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
|
||||
@@ -110,6 +134,7 @@ PLATFORMS = {
|
||||
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
|
||||
"signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
|
||||
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
|
||||
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
|
||||
}
|
||||
|
||||
|
||||
@@ -150,19 +175,37 @@ TOOL_CATEGORIES = {
|
||||
"web": {
|
||||
"name": "Web Search & Extract",
|
||||
"setup_title": "Select Search Provider",
|
||||
"setup_note": "A free DuckDuckGo search skill is also included — skip this if you don't need Firecrawl.",
|
||||
"setup_note": "A free DuckDuckGo search skill is also included — skip this if you don't need a premium provider.",
|
||||
"icon": "🔍",
|
||||
"providers": [
|
||||
{
|
||||
"name": "Firecrawl Cloud",
|
||||
"tag": "Recommended - hosted service",
|
||||
"tag": "Hosted service - search, extract, and crawl",
|
||||
"web_backend": "firecrawl",
|
||||
"env_vars": [
|
||||
{"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Parallel",
|
||||
"tag": "AI-native search and extract",
|
||||
"web_backend": "parallel",
|
||||
"env_vars": [
|
||||
{"key": "PARALLEL_API_KEY", "prompt": "Parallel API key", "url": "https://parallel.ai"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Tavily",
|
||||
"tag": "AI-native search, extract, and crawl",
|
||||
"web_backend": "tavily",
|
||||
"env_vars": [
|
||||
{"key": "TAVILY_API_KEY", "prompt": "Tavily API key", "url": "https://app.tavily.com/home"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Firecrawl Self-Hosted",
|
||||
"tag": "Free - run your own instance",
|
||||
"web_backend": "firecrawl",
|
||||
"env_vars": [
|
||||
{"key": "FIRECRAWL_API_URL", "prompt": "Your Firecrawl instance URL (e.g., http://localhost:3002)"},
|
||||
],
|
||||
@@ -190,6 +233,7 @@ TOOL_CATEGORIES = {
|
||||
"name": "Local Browser",
|
||||
"tag": "Free headless Chromium (no API key needed)",
|
||||
"env_vars": [],
|
||||
"browser_provider": None,
|
||||
"post_setup": "browserbase", # Same npm install for agent-browser
|
||||
},
|
||||
{
|
||||
@@ -199,6 +243,16 @@ TOOL_CATEGORIES = {
|
||||
{"key": "BROWSERBASE_API_KEY", "prompt": "Browserbase API key", "url": "https://browserbase.com"},
|
||||
{"key": "BROWSERBASE_PROJECT_ID", "prompt": "Browserbase project ID"},
|
||||
],
|
||||
"browser_provider": "browserbase",
|
||||
"post_setup": "browserbase",
|
||||
},
|
||||
{
|
||||
"name": "Browser Use",
|
||||
"tag": "Cloud browser with remote execution",
|
||||
"env_vars": [
|
||||
{"key": "BROWSER_USE_API_KEY", "prompt": "Browser Use API key", "url": "https://browser-use.com"},
|
||||
],
|
||||
"browser_provider": "browser-use",
|
||||
"post_setup": "browserbase",
|
||||
},
|
||||
],
|
||||
@@ -337,18 +391,46 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
default_ts = PLATFORMS[platform]["default_toolset"]
|
||||
toolset_names = [default_ts]
|
||||
|
||||
# Resolve to individual tool names, then map back to which
|
||||
# configurable toolsets are covered
|
||||
all_tool_names = set()
|
||||
for ts_name in toolset_names:
|
||||
all_tool_names.update(resolve_toolset(ts_name))
|
||||
configurable_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
|
||||
# Map individual tool names back to configurable toolset keys
|
||||
enabled_toolsets = set()
|
||||
for ts_key, _, _ in CONFIGURABLE_TOOLSETS:
|
||||
ts_tools = set(resolve_toolset(ts_key))
|
||||
if ts_tools and ts_tools.issubset(all_tool_names):
|
||||
enabled_toolsets.add(ts_key)
|
||||
# If the saved list contains any configurable keys directly, the user
|
||||
# has explicitly configured this platform — use direct membership.
|
||||
# This avoids the subset-inference bug where composite toolsets like
|
||||
# "hermes-cli" (which include all _HERMES_CORE_TOOLS) cause disabled
|
||||
# toolsets to re-appear as enabled.
|
||||
has_explicit_config = any(ts in configurable_keys for ts in toolset_names)
|
||||
|
||||
if has_explicit_config:
|
||||
enabled_toolsets = {ts for ts in toolset_names if ts in configurable_keys}
|
||||
else:
|
||||
# No explicit config — fall back to resolving composite toolset names
|
||||
# (e.g. "hermes-cli") to individual tool names and reverse-mapping.
|
||||
all_tool_names = set()
|
||||
for ts_name in toolset_names:
|
||||
all_tool_names.update(resolve_toolset(ts_name))
|
||||
|
||||
enabled_toolsets = set()
|
||||
for ts_key, _, _ in CONFIGURABLE_TOOLSETS:
|
||||
ts_tools = set(resolve_toolset(ts_key))
|
||||
if ts_tools and ts_tools.issubset(all_tool_names):
|
||||
enabled_toolsets.add(ts_key)
|
||||
|
||||
# Plugin toolsets: enabled by default unless explicitly disabled.
|
||||
# A plugin toolset is "known" for a platform once `hermes tools`
|
||||
# has been saved for that platform (tracked via known_plugin_toolsets).
|
||||
# Unknown plugins default to enabled; known-but-absent = disabled.
|
||||
plugin_ts_keys = _get_plugin_toolset_keys()
|
||||
if plugin_ts_keys:
|
||||
known_map = config.get("known_plugin_toolsets", {})
|
||||
known_for_platform = set(known_map.get(platform, []))
|
||||
for pts in plugin_ts_keys:
|
||||
if pts in toolset_names:
|
||||
# Explicitly listed in config — enabled
|
||||
enabled_toolsets.add(pts)
|
||||
elif pts not in known_for_platform:
|
||||
# New plugin not yet seen by hermes tools — default enabled
|
||||
enabled_toolsets.add(pts)
|
||||
# else: known but not in config = user disabled it
|
||||
|
||||
return enabled_toolsets
|
||||
|
||||
@@ -361,22 +443,37 @@ def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[
|
||||
"""
|
||||
config.setdefault("platform_toolsets", {})
|
||||
|
||||
# Get the set of all configurable toolset keys
|
||||
# Get the set of all configurable toolset keys (built-in + plugin)
|
||||
configurable_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
plugin_keys = _get_plugin_toolset_keys()
|
||||
configurable_keys |= plugin_keys
|
||||
|
||||
# Also exclude platform default toolsets (hermes-cli, hermes-telegram, etc.)
|
||||
# These are "super" toolsets that resolve to ALL tools, so preserving them
|
||||
# would silently override the user's unchecked selections on the next read.
|
||||
platform_default_keys = {p["default_toolset"] for p in PLATFORMS.values()}
|
||||
|
||||
# Get existing toolsets for this platform
|
||||
existing_toolsets = config.get("platform_toolsets", {}).get(platform, [])
|
||||
if not isinstance(existing_toolsets, list):
|
||||
existing_toolsets = []
|
||||
|
||||
# Preserve any entries that are NOT configurable toolsets (i.e. MCP server names)
|
||||
# Preserve any entries that are NOT configurable toolsets and NOT platform
|
||||
# defaults (i.e. only MCP server names should be preserved)
|
||||
preserved_entries = {
|
||||
entry for entry in existing_toolsets
|
||||
if entry not in configurable_keys
|
||||
if entry not in configurable_keys and entry not in platform_default_keys
|
||||
}
|
||||
|
||||
# Merge preserved entries with new enabled toolsets
|
||||
config["platform_toolsets"][platform] = sorted(enabled_toolset_keys | preserved_entries)
|
||||
|
||||
# Track which plugin toolsets are "known" for this platform so we can
|
||||
# distinguish "new plugin, default enabled" from "user disabled it".
|
||||
if plugin_keys:
|
||||
config.setdefault("known_plugin_toolsets", {})
|
||||
config["known_plugin_toolsets"][platform] = sorted(plugin_keys)
|
||||
|
||||
save_config(config)
|
||||
|
||||
|
||||
@@ -494,15 +591,17 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
"""Multi-select checklist of toolsets. Returns set of selected toolset keys."""
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
effective = _get_effective_configurable_toolsets()
|
||||
|
||||
labels = []
|
||||
for ts_key, ts_label, ts_desc in CONFIGURABLE_TOOLSETS:
|
||||
for ts_key, ts_label, ts_desc in effective:
|
||||
suffix = ""
|
||||
if not _toolset_has_keys(ts_key) and (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
|
||||
suffix = " [no API key]"
|
||||
labels.append(f"{ts_label} ({ts_desc}){suffix}")
|
||||
|
||||
pre_selected = {
|
||||
i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)
|
||||
i for i, (ts_key, _, _) in enumerate(effective)
|
||||
if ts_key in enabled
|
||||
}
|
||||
|
||||
@@ -512,7 +611,7 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
pre_selected,
|
||||
cancel_returns=pre_selected,
|
||||
)
|
||||
return {CONFIGURABLE_TOOLSETS[i][0] for i in chosen}
|
||||
return {effective[i][0] for i in chosen}
|
||||
|
||||
|
||||
# ─── Provider-Aware Configuration ────────────────────────────────────────────
|
||||
@@ -575,10 +674,10 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
configured = ""
|
||||
env_vars = p.get("env_vars", [])
|
||||
if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
|
||||
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
|
||||
if _is_provider_active(p, config):
|
||||
configured = " [active]"
|
||||
elif not env_vars:
|
||||
configured = " [active]" if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "") else ""
|
||||
configured = ""
|
||||
else:
|
||||
configured = " [configured]"
|
||||
provider_choices.append(f"{p['name']}{tag}{configured}")
|
||||
@@ -587,15 +686,7 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
provider_choices.append("Skip — keep defaults / configure later")
|
||||
|
||||
# Detect current provider as default
|
||||
default_idx = 0
|
||||
for i, p in enumerate(providers):
|
||||
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
|
||||
default_idx = i
|
||||
break
|
||||
env_vars = p.get("env_vars", [])
|
||||
if env_vars and all(get_env_value(v["key"]) for v in env_vars):
|
||||
default_idx = i
|
||||
break
|
||||
default_idx = _detect_active_provider_index(providers, config)
|
||||
|
||||
provider_idx = _prompt_choice(f" {title}:", provider_choices, default_idx)
|
||||
|
||||
@@ -607,6 +698,31 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
_configure_provider(providers[provider_idx], config)
|
||||
|
||||
|
||||
def _is_provider_active(provider: dict, config: dict) -> bool:
|
||||
"""Check if a provider entry matches the currently active config."""
|
||||
if provider.get("tts_provider"):
|
||||
return config.get("tts", {}).get("provider") == provider["tts_provider"]
|
||||
if "browser_provider" in provider:
|
||||
current = config.get("browser", {}).get("cloud_provider")
|
||||
return provider["browser_provider"] == current
|
||||
if provider.get("web_backend"):
|
||||
current = config.get("web", {}).get("backend")
|
||||
return current == provider["web_backend"]
|
||||
return False
|
||||
|
||||
|
||||
def _detect_active_provider_index(providers: list, config: dict) -> int:
|
||||
"""Return the index of the currently active provider, or 0."""
|
||||
for i, p in enumerate(providers):
|
||||
if _is_provider_active(p, config):
|
||||
return i
|
||||
# Fallback: env vars present → likely configured
|
||||
env_vars = p.get("env_vars", [])
|
||||
if env_vars and all(get_env_value(v["key"]) for v in env_vars):
|
||||
return i
|
||||
return 0
|
||||
|
||||
|
||||
def _configure_provider(provider: dict, config: dict):
|
||||
"""Configure a single provider - prompt for API keys and set config."""
|
||||
env_vars = provider.get("env_vars", [])
|
||||
@@ -615,6 +731,20 @@ def _configure_provider(provider: dict, config: dict):
|
||||
if provider.get("tts_provider"):
|
||||
config.setdefault("tts", {})["provider"] = provider["tts_provider"]
|
||||
|
||||
# Set browser cloud provider in config if applicable
|
||||
if "browser_provider" in provider:
|
||||
bp = provider["browser_provider"]
|
||||
if bp:
|
||||
config.setdefault("browser", {})["cloud_provider"] = bp
|
||||
_print_success(f" Browser cloud provider set to: {bp}")
|
||||
else:
|
||||
config.get("browser", {}).pop("cloud_provider", None)
|
||||
|
||||
# Set web search backend in config if applicable
|
||||
if provider.get("web_backend"):
|
||||
config.setdefault("web", {})["backend"] = provider["web_backend"]
|
||||
_print_success(f" Web backend set to: {provider['web_backend']}")
|
||||
|
||||
if not env_vars:
|
||||
_print_success(f" {provider['name']} - no configuration needed!")
|
||||
return
|
||||
@@ -696,7 +826,7 @@ def _configure_simple_requirements(ts_key: str):
|
||||
if not missing:
|
||||
return
|
||||
|
||||
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
ts_label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print()
|
||||
print(color(f" {ts_label} requires configuration:", Colors.YELLOW))
|
||||
|
||||
@@ -715,7 +845,7 @@ def _reconfigure_tool(config: dict):
|
||||
"""Let user reconfigure an existing tool's provider or API key."""
|
||||
# Build list of configurable tools that are currently set up
|
||||
configurable = []
|
||||
for ts_key, ts_label, _ in CONFIGURABLE_TOOLSETS:
|
||||
for ts_key, ts_label, _ in _get_effective_configurable_toolsets():
|
||||
cat = TOOL_CATEGORIES.get(ts_key)
|
||||
reqs = TOOLSET_ENV_REQUIREMENTS.get(ts_key)
|
||||
if cat or reqs:
|
||||
@@ -767,7 +897,7 @@ def _configure_tool_category_for_reconfig(ts_key: str, cat: dict, config: dict):
|
||||
configured = ""
|
||||
env_vars = p.get("env_vars", [])
|
||||
if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
|
||||
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
|
||||
if _is_provider_active(p, config):
|
||||
configured = " [active]"
|
||||
elif not env_vars:
|
||||
configured = ""
|
||||
@@ -775,15 +905,7 @@ def _configure_tool_category_for_reconfig(ts_key: str, cat: dict, config: dict):
|
||||
configured = " [configured]"
|
||||
provider_choices.append(f"{p['name']}{tag}{configured}")
|
||||
|
||||
default_idx = 0
|
||||
for i, p in enumerate(providers):
|
||||
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
|
||||
default_idx = i
|
||||
break
|
||||
env_vars = p.get("env_vars", [])
|
||||
if env_vars and all(get_env_value(v["key"]) for v in env_vars):
|
||||
default_idx = i
|
||||
break
|
||||
default_idx = _detect_active_provider_index(providers, config)
|
||||
|
||||
provider_idx = _prompt_choice(" Select provider:", provider_choices, default_idx)
|
||||
_reconfigure_provider(providers[provider_idx], config)
|
||||
@@ -797,6 +919,20 @@ def _reconfigure_provider(provider: dict, config: dict):
|
||||
config.setdefault("tts", {})["provider"] = provider["tts_provider"]
|
||||
_print_success(f" TTS provider set to: {provider['tts_provider']}")
|
||||
|
||||
if "browser_provider" in provider:
|
||||
bp = provider["browser_provider"]
|
||||
if bp:
|
||||
config.setdefault("browser", {})["cloud_provider"] = bp
|
||||
_print_success(f" Browser cloud provider set to: {bp}")
|
||||
else:
|
||||
config.get("browser", {}).pop("cloud_provider", None)
|
||||
_print_success(f" Browser set to local mode")
|
||||
|
||||
# Set web search backend in config if applicable
|
||||
if provider.get("web_backend"):
|
||||
config.setdefault("web", {})["backend"] = provider["web_backend"]
|
||||
_print_success(f" Web backend set to: {provider['web_backend']}")
|
||||
|
||||
if not env_vars:
|
||||
_print_success(f" {provider['name']} - no configuration needed!")
|
||||
return
|
||||
@@ -823,7 +959,7 @@ def _reconfigure_simple_requirements(ts_key: str):
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
ts_label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print()
|
||||
print(color(f" {ts_label}:", Colors.CYAN))
|
||||
|
||||
@@ -862,7 +998,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
|
||||
# Non-interactive summary mode for CLI usage
|
||||
if getattr(args, "summary", False):
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
print(color("⚕ Tool Summary", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
summary = _platform_toolset_summary(config, enabled_platforms)
|
||||
@@ -873,7 +1009,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
print(color(f" {pinfo['label']}", Colors.BOLD) + color(f" ({count}/{total})", Colors.DIM))
|
||||
if enabled:
|
||||
for ts_key in sorted(enabled):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print(color(f" ✓ {label}", Colors.GREEN))
|
||||
else:
|
||||
print(color(" (none enabled)", Colors.DIM))
|
||||
@@ -900,11 +1036,11 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
removed = current_enabled - new_enabled
|
||||
if added:
|
||||
for ts in sorted(added):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" + {label}", Colors.GREEN))
|
||||
if removed:
|
||||
for ts in sorted(removed):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" - {label}", Colors.RED))
|
||||
|
||||
# Walk through ALL selected tools that have provider options or
|
||||
@@ -920,7 +1056,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
print()
|
||||
print(color(f" Configuring {len(to_configure)} tool(s):", Colors.YELLOW))
|
||||
for ts_key in to_configure:
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print(color(f" • {label}", Colors.DIM))
|
||||
print(color(" You can skip any tool you don't need right now.", Colors.DIM))
|
||||
print()
|
||||
@@ -942,19 +1078,26 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
pinfo = PLATFORMS[pkey]
|
||||
current = _get_platform_tools(config, pkey)
|
||||
count = len(current)
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
platform_choices.append(f"Configure {pinfo['label']} ({count}/{total} enabled)")
|
||||
platform_keys.append(pkey)
|
||||
|
||||
if len(platform_keys) > 1:
|
||||
platform_choices.append("Configure all platforms (global)")
|
||||
platform_choices.append("Reconfigure an existing tool's provider or API key")
|
||||
|
||||
# Show MCP option if any MCP servers are configured
|
||||
_has_mcp = bool(config.get("mcp_servers"))
|
||||
if _has_mcp:
|
||||
platform_choices.append("Configure MCP server tools")
|
||||
|
||||
platform_choices.append("Done")
|
||||
|
||||
# Index offsets for the extra options after per-platform entries
|
||||
_global_idx = len(platform_keys) if len(platform_keys) > 1 else -1
|
||||
_reconfig_idx = len(platform_keys) + (1 if len(platform_keys) > 1 else 0)
|
||||
_done_idx = _reconfig_idx + 1
|
||||
_mcp_idx = (_reconfig_idx + 1) if _has_mcp else -1
|
||||
_done_idx = _reconfig_idx + (2 if _has_mcp else 1)
|
||||
|
||||
while True:
|
||||
idx = _prompt_choice("Select an option:", platform_choices, default=0)
|
||||
@@ -969,6 +1112,12 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
print()
|
||||
continue
|
||||
|
||||
# "Configure MCP tools" selected
|
||||
if idx == _mcp_idx:
|
||||
_configure_mcp_tools_interactive(config)
|
||||
print()
|
||||
continue
|
||||
|
||||
# "Configure all platforms (global)" selected
|
||||
if idx == _global_idx:
|
||||
# Use the union of all platforms' current tools as the starting state
|
||||
@@ -985,10 +1134,10 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
if added or removed:
|
||||
print(color(f" {pinfo_inner['label']}:", Colors.DIM))
|
||||
for ts in sorted(added):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" + {label}", Colors.GREEN))
|
||||
for ts in sorted(removed):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" - {label}", Colors.RED))
|
||||
# Configure API keys for newly enabled tools
|
||||
for ts_key in sorted(added):
|
||||
@@ -1001,7 +1150,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
# Update choice labels
|
||||
for ci, pk in enumerate(platform_keys):
|
||||
new_count = len(_get_platform_tools(config, pk))
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
platform_choices[ci] = f"Configure {PLATFORMS[pk]['label']} ({new_count}/{total} enabled)"
|
||||
else:
|
||||
print(color(" No changes", Colors.DIM))
|
||||
@@ -1023,11 +1172,11 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
|
||||
if added:
|
||||
for ts in sorted(added):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" + {label}", Colors.GREEN))
|
||||
if removed:
|
||||
for ts in sorted(removed):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" - {label}", Colors.RED))
|
||||
|
||||
# Configure newly enabled toolsets that need API keys
|
||||
@@ -1046,10 +1195,267 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
|
||||
# Update the choice label with new count
|
||||
new_count = len(_get_platform_tools(config, pkey))
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
platform_choices[idx] = f"Configure {pinfo['label']} ({new_count}/{total} enabled)"
|
||||
|
||||
print()
|
||||
print(color(" Tool configuration saved to ~/.hermes/config.yaml", Colors.DIM))
|
||||
print(color(" Changes take effect on next 'hermes' or gateway restart.", Colors.DIM))
|
||||
print()
|
||||
|
||||
|
||||
# ─── MCP Tools Interactive Configuration ─────────────────────────────────────
|
||||
|
||||
|
||||
def _configure_mcp_tools_interactive(config: dict):
|
||||
"""Probe MCP servers for available tools and let user toggle them on/off.
|
||||
|
||||
Connects to each configured MCP server, discovers tools, then shows
|
||||
a per-server curses checklist. Writes changes back as ``tools.exclude``
|
||||
entries in config.yaml.
|
||||
"""
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
mcp_servers = config.get("mcp_servers") or {}
|
||||
if not mcp_servers:
|
||||
_print_info("No MCP servers configured.")
|
||||
return
|
||||
|
||||
# Count enabled servers
|
||||
enabled_names = [
|
||||
k for k, v in mcp_servers.items()
|
||||
if v.get("enabled", True) not in (False, "false", "0", "no", "off")
|
||||
]
|
||||
if not enabled_names:
|
||||
_print_info("All MCP servers are disabled.")
|
||||
return
|
||||
|
||||
print()
|
||||
print(color(" Discovering tools from MCP servers...", Colors.YELLOW))
|
||||
print(color(f" Connecting to {len(enabled_names)} server(s): {', '.join(enabled_names)}", Colors.DIM))
|
||||
|
||||
try:
|
||||
from tools.mcp_tool import probe_mcp_server_tools
|
||||
server_tools = probe_mcp_server_tools()
|
||||
except Exception as exc:
|
||||
_print_error(f"Failed to probe MCP servers: {exc}")
|
||||
return
|
||||
|
||||
if not server_tools:
|
||||
_print_warning("Could not discover tools from any MCP server.")
|
||||
_print_info("Check that server commands/URLs are correct and dependencies are installed.")
|
||||
return
|
||||
|
||||
# Report discovery results
|
||||
failed = [n for n in enabled_names if n not in server_tools]
|
||||
if failed:
|
||||
for name in failed:
|
||||
_print_warning(f" Could not connect to '{name}'")
|
||||
|
||||
total_tools = sum(len(tools) for tools in server_tools.values())
|
||||
print(color(f" Found {total_tools} tool(s) across {len(server_tools)} server(s)", Colors.GREEN))
|
||||
print()
|
||||
|
||||
any_changes = False
|
||||
|
||||
for server_name, tools in server_tools.items():
|
||||
if not tools:
|
||||
_print_info(f" {server_name}: no tools found")
|
||||
continue
|
||||
|
||||
srv_cfg = mcp_servers.get(server_name, {})
|
||||
tools_cfg = srv_cfg.get("tools") or {}
|
||||
include_list = tools_cfg.get("include") or []
|
||||
exclude_list = tools_cfg.get("exclude") or []
|
||||
|
||||
# Build checklist labels
|
||||
labels = []
|
||||
for tool_name, description in tools:
|
||||
desc_short = description[:70] + "..." if len(description) > 70 else description
|
||||
if desc_short:
|
||||
labels.append(f"{tool_name} ({desc_short})")
|
||||
else:
|
||||
labels.append(tool_name)
|
||||
|
||||
# Determine which tools are currently enabled
|
||||
pre_selected: Set[int] = set()
|
||||
tool_names = [t[0] for t in tools]
|
||||
for i, tool_name in enumerate(tool_names):
|
||||
if include_list:
|
||||
# Include mode: only included tools are selected
|
||||
if tool_name in include_list:
|
||||
pre_selected.add(i)
|
||||
elif exclude_list:
|
||||
# Exclude mode: everything except excluded
|
||||
if tool_name not in exclude_list:
|
||||
pre_selected.add(i)
|
||||
else:
|
||||
# No filter: all enabled
|
||||
pre_selected.add(i)
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"MCP Server: {server_name} ({len(tools)} tools)",
|
||||
labels,
|
||||
pre_selected,
|
||||
cancel_returns=pre_selected,
|
||||
)
|
||||
|
||||
if chosen == pre_selected:
|
||||
_print_info(f" {server_name}: no changes")
|
||||
continue
|
||||
|
||||
# Compute new exclude list based on unchecked tools
|
||||
new_exclude = [tool_names[i] for i in range(len(tool_names)) if i not in chosen]
|
||||
|
||||
# Update config
|
||||
srv_cfg = mcp_servers.setdefault(server_name, {})
|
||||
tools_cfg = srv_cfg.setdefault("tools", {})
|
||||
|
||||
if new_exclude:
|
||||
tools_cfg["exclude"] = new_exclude
|
||||
# Remove include if present — we're switching to exclude mode
|
||||
tools_cfg.pop("include", None)
|
||||
else:
|
||||
# All tools enabled — clear filters
|
||||
tools_cfg.pop("exclude", None)
|
||||
tools_cfg.pop("include", None)
|
||||
|
||||
enabled_count = len(chosen)
|
||||
disabled_count = len(tools) - enabled_count
|
||||
_print_success(
|
||||
f" {server_name}: {enabled_count} enabled, {disabled_count} disabled"
|
||||
)
|
||||
any_changes = True
|
||||
|
||||
if any_changes:
|
||||
save_config(config)
|
||||
print()
|
||||
print(color(" ✓ MCP tool configuration saved", Colors.GREEN))
|
||||
else:
|
||||
print(color(" No changes to MCP tools", Colors.DIM))
|
||||
|
||||
|
||||
# ─── Non-interactive disable/enable ──────────────────────────────────────────
|
||||
|
||||
|
||||
def _apply_toolset_change(config: dict, platform: str, toolset_names: List[str], action: str):
|
||||
"""Add or remove built-in toolsets for a platform."""
|
||||
enabled = _get_platform_tools(config, platform)
|
||||
if action == "disable":
|
||||
updated = enabled - set(toolset_names)
|
||||
else:
|
||||
updated = enabled | set(toolset_names)
|
||||
_save_platform_tools(config, platform, updated)
|
||||
|
||||
|
||||
def _apply_mcp_change(config: dict, targets: List[str], action: str) -> Set[str]:
|
||||
"""Add or remove specific MCP tools from a server's exclude list.
|
||||
|
||||
Returns the set of server names that were not found in config.
|
||||
"""
|
||||
failed_servers: Set[str] = set()
|
||||
mcp_servers = config.get("mcp_servers") or {}
|
||||
|
||||
for target in targets:
|
||||
server_name, tool_name = target.split(":", 1)
|
||||
if server_name not in mcp_servers:
|
||||
failed_servers.add(server_name)
|
||||
continue
|
||||
tools_cfg = mcp_servers[server_name].setdefault("tools", {})
|
||||
exclude = list(tools_cfg.get("exclude") or [])
|
||||
if action == "disable":
|
||||
if tool_name not in exclude:
|
||||
exclude.append(tool_name)
|
||||
else:
|
||||
exclude = [t for t in exclude if t != tool_name]
|
||||
tools_cfg["exclude"] = exclude
|
||||
|
||||
return failed_servers
|
||||
|
||||
|
||||
def _print_tools_list(enabled_toolsets: set, mcp_servers: dict, platform: str = "cli"):
|
||||
"""Print a summary of enabled/disabled toolsets and MCP tool filters."""
|
||||
effective = _get_effective_configurable_toolsets()
|
||||
builtin_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
|
||||
print(f"Built-in toolsets ({platform}):")
|
||||
for ts_key, label, _ in effective:
|
||||
if ts_key not in builtin_keys:
|
||||
continue
|
||||
status = (color("✓ enabled", Colors.GREEN) if ts_key in enabled_toolsets
|
||||
else color("✗ disabled", Colors.RED))
|
||||
print(f" {status} {ts_key} {color(label, Colors.DIM)}")
|
||||
|
||||
# Plugin toolsets
|
||||
plugin_entries = [(k, l) for k, l, _ in effective if k not in builtin_keys]
|
||||
if plugin_entries:
|
||||
print()
|
||||
print(f"Plugin toolsets ({platform}):")
|
||||
for ts_key, label in plugin_entries:
|
||||
status = (color("✓ enabled", Colors.GREEN) if ts_key in enabled_toolsets
|
||||
else color("✗ disabled", Colors.RED))
|
||||
print(f" {status} {ts_key} {color(label, Colors.DIM)}")
|
||||
|
||||
if mcp_servers:
|
||||
print()
|
||||
print("MCP servers:")
|
||||
for srv_name, srv_cfg in mcp_servers.items():
|
||||
tools_cfg = srv_cfg.get("tools") or {}
|
||||
exclude = tools_cfg.get("exclude") or []
|
||||
include = tools_cfg.get("include") or []
|
||||
if include:
|
||||
_print_info(f"{srv_name} [include only: {', '.join(include)}]")
|
||||
elif exclude:
|
||||
_print_info(f"{srv_name} [excluded: {color(', '.join(exclude), Colors.YELLOW)}]")
|
||||
else:
|
||||
_print_info(f"{srv_name} {color('all tools enabled', Colors.DIM)}")
|
||||
|
||||
|
||||
def tools_disable_enable_command(args):
|
||||
"""Enable, disable, or list tools for a platform.
|
||||
|
||||
Built-in toolsets use plain names (e.g. ``web``, ``memory``).
|
||||
MCP tools use ``server:tool`` notation (e.g. ``github:create_issue``).
|
||||
"""
|
||||
action = args.tools_action
|
||||
platform = getattr(args, "platform", "cli")
|
||||
config = load_config()
|
||||
|
||||
if platform not in PLATFORMS:
|
||||
_print_error(f"Unknown platform '{platform}'. Valid: {', '.join(PLATFORMS)}")
|
||||
return
|
||||
|
||||
if action == "list":
|
||||
_print_tools_list(_get_platform_tools(config, platform),
|
||||
config.get("mcp_servers") or {}, platform)
|
||||
return
|
||||
|
||||
targets: List[str] = args.names
|
||||
toolset_targets = [t for t in targets if ":" not in t]
|
||||
mcp_targets = [t for t in targets if ":" in t]
|
||||
|
||||
valid_toolsets = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS} | _get_plugin_toolset_keys()
|
||||
unknown_toolsets = [t for t in toolset_targets if t not in valid_toolsets]
|
||||
if unknown_toolsets:
|
||||
for name in unknown_toolsets:
|
||||
_print_error(f"Unknown toolset '{name}'")
|
||||
toolset_targets = [t for t in toolset_targets if t in valid_toolsets]
|
||||
|
||||
if toolset_targets:
|
||||
_apply_toolset_change(config, platform, toolset_targets, action)
|
||||
|
||||
failed_servers: Set[str] = set()
|
||||
if mcp_targets:
|
||||
failed_servers = _apply_mcp_change(config, mcp_targets, action)
|
||||
for srv in failed_servers:
|
||||
_print_error(f"MCP server '{srv}' not found in config")
|
||||
|
||||
save_config(config)
|
||||
|
||||
successful = [
|
||||
t for t in targets
|
||||
if t not in unknown_toolsets and (":" not in t or t.split(":")[0] not in failed_servers)
|
||||
]
|
||||
if successful:
|
||||
verb = "Disabled" if action == "disable" else "Enabled"
|
||||
_print_success(f"{verb}: {', '.join(successful)}")
|
||||
|
||||
@@ -133,7 +133,13 @@ def uninstall_gateway_service():
|
||||
if platform.system() != "Linux":
|
||||
return False
|
||||
|
||||
service_file = Path.home() / ".config" / "systemd" / "user" / "hermes-gateway.service"
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
svc_name = get_service_name()
|
||||
except Exception:
|
||||
svc_name = "hermes-gateway"
|
||||
|
||||
service_file = Path.home() / ".config" / "systemd" / "user" / f"{svc_name}.service"
|
||||
|
||||
if not service_file.exists():
|
||||
return False
|
||||
@@ -141,14 +147,14 @@ def uninstall_gateway_service():
|
||||
try:
|
||||
# Stop the service
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "stop", "hermes-gateway"],
|
||||
["systemctl", "--user", "stop", svc_name],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
# Disable the service
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "disable", "hermes-gateway"],
|
||||
["systemctl", "--user", "disable", svc_name],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
@@ -8,5 +8,9 @@ OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"
|
||||
OPENROUTER_CHAT_URL = f"{OPENROUTER_BASE_URL}/chat/completions"
|
||||
|
||||
AI_GATEWAY_BASE_URL = "https://ai-gateway.vercel.sh/v1"
|
||||
AI_GATEWAY_MODELS_URL = f"{AI_GATEWAY_BASE_URL}/models"
|
||||
AI_GATEWAY_CHAT_URL = f"{AI_GATEWAY_BASE_URL}/chat/completions"
|
||||
|
||||
NOUS_API_BASE_URL = "https://inference-api.nousresearch.com/v1"
|
||||
NOUS_API_CHAT_URL = f"{NOUS_API_BASE_URL}/chat/completions"
|
||||
|
||||
+358
-227
@@ -18,6 +18,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
@@ -25,7 +26,7 @@ from typing import Dict, Any, List, Optional
|
||||
|
||||
DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db"
|
||||
|
||||
SCHEMA_VERSION = 4
|
||||
SCHEMA_VERSION = 5
|
||||
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
@@ -47,6 +48,17 @@ CREATE TABLE IF NOT EXISTS sessions (
|
||||
tool_call_count INTEGER DEFAULT 0,
|
||||
input_tokens INTEGER DEFAULT 0,
|
||||
output_tokens INTEGER DEFAULT 0,
|
||||
cache_read_tokens INTEGER DEFAULT 0,
|
||||
cache_write_tokens INTEGER DEFAULT 0,
|
||||
reasoning_tokens INTEGER DEFAULT 0,
|
||||
billing_provider TEXT,
|
||||
billing_base_url TEXT,
|
||||
billing_mode TEXT,
|
||||
estimated_cost_usd REAL,
|
||||
actual_cost_usd REAL,
|
||||
cost_status TEXT,
|
||||
cost_source TEXT,
|
||||
pricing_version TEXT,
|
||||
title TEXT,
|
||||
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
|
||||
);
|
||||
@@ -104,6 +116,7 @@ class SessionDB:
|
||||
self.db_path = db_path or DEFAULT_DB_PATH
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._conn = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
check_same_thread=False,
|
||||
@@ -152,6 +165,30 @@ class SessionDB:
|
||||
except sqlite3.OperationalError:
|
||||
pass # Index already exists
|
||||
cursor.execute("UPDATE schema_version SET version = 4")
|
||||
if current_version < 5:
|
||||
new_columns = [
|
||||
("cache_read_tokens", "INTEGER DEFAULT 0"),
|
||||
("cache_write_tokens", "INTEGER DEFAULT 0"),
|
||||
("reasoning_tokens", "INTEGER DEFAULT 0"),
|
||||
("billing_provider", "TEXT"),
|
||||
("billing_base_url", "TEXT"),
|
||||
("billing_mode", "TEXT"),
|
||||
("estimated_cost_usd", "REAL"),
|
||||
("actual_cost_usd", "REAL"),
|
||||
("cost_status", "TEXT"),
|
||||
("cost_source", "TEXT"),
|
||||
("pricing_version", "TEXT"),
|
||||
]
|
||||
for name, column_type in new_columns:
|
||||
try:
|
||||
# name and column_type come from the hardcoded tuple above,
|
||||
# not user input. Double-quote identifier escaping is applied
|
||||
# as defense-in-depth; SQLite DDL cannot be parameterized.
|
||||
safe_name = name.replace('"', '""')
|
||||
cursor.execute(f'ALTER TABLE sessions ADD COLUMN "{safe_name}" {column_type}')
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
cursor.execute("UPDATE schema_version SET version = 5")
|
||||
|
||||
# Unique title index — always ensure it exists (safe to run after migrations
|
||||
# since the title column is guaranteed to exist at this point)
|
||||
@@ -173,9 +210,10 @@ class SessionDB:
|
||||
|
||||
def close(self):
|
||||
"""Close the database connection."""
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
with self._lock:
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
# =========================================================================
|
||||
# Session lifecycle
|
||||
@@ -192,61 +230,111 @@ class SessionDB:
|
||||
parent_session_id: str = None,
|
||||
) -> str:
|
||||
"""Create a new session record. Returns the session_id."""
|
||||
self._conn.execute(
|
||||
"""INSERT INTO sessions (id, source, user_id, model, model_config,
|
||||
system_prompt, parent_session_id, started_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
session_id,
|
||||
source,
|
||||
user_id,
|
||||
model,
|
||||
json.dumps(model_config) if model_config else None,
|
||||
system_prompt,
|
||||
parent_session_id,
|
||||
time.time(),
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"""INSERT INTO sessions (id, source, user_id, model, model_config,
|
||||
system_prompt, parent_session_id, started_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
session_id,
|
||||
source,
|
||||
user_id,
|
||||
model,
|
||||
json.dumps(model_config) if model_config else None,
|
||||
system_prompt,
|
||||
parent_session_id,
|
||||
time.time(),
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
return session_id
|
||||
|
||||
def end_session(self, session_id: str, end_reason: str) -> None:
|
||||
"""Mark a session as ended."""
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?",
|
||||
(time.time(), end_reason, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?",
|
||||
(time.time(), end_reason, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
|
||||
"""Store the full assembled system prompt snapshot."""
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET system_prompt = ? WHERE id = ?",
|
||||
(system_prompt, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET system_prompt = ? WHERE id = ?",
|
||||
(system_prompt, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def update_token_counts(
|
||||
self, session_id: str, input_tokens: int = 0, output_tokens: int = 0,
|
||||
self,
|
||||
session_id: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
model: str = None,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
reasoning_tokens: int = 0,
|
||||
estimated_cost_usd: Optional[float] = None,
|
||||
actual_cost_usd: Optional[float] = None,
|
||||
cost_status: Optional[str] = None,
|
||||
cost_source: Optional[str] = None,
|
||||
pricing_version: Optional[str] = None,
|
||||
billing_provider: Optional[str] = None,
|
||||
billing_base_url: Optional[str] = None,
|
||||
billing_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Increment token counters and backfill model if not already set."""
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
input_tokens = input_tokens + ?,
|
||||
output_tokens = output_tokens + ?,
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?""",
|
||||
(input_tokens, output_tokens, model, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
input_tokens = input_tokens + ?,
|
||||
output_tokens = output_tokens + ?,
|
||||
cache_read_tokens = cache_read_tokens + ?,
|
||||
cache_write_tokens = cache_write_tokens + ?,
|
||||
reasoning_tokens = reasoning_tokens + ?,
|
||||
estimated_cost_usd = COALESCE(estimated_cost_usd, 0) + COALESCE(?, 0),
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE COALESCE(actual_cost_usd, 0) + ?
|
||||
END,
|
||||
cost_status = COALESCE(?, cost_status),
|
||||
cost_source = COALESCE(?, cost_source),
|
||||
pricing_version = COALESCE(?, pricing_version),
|
||||
billing_provider = COALESCE(billing_provider, ?),
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?""",
|
||||
(
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
reasoning_tokens,
|
||||
estimated_cost_usd,
|
||||
actual_cost_usd,
|
||||
actual_cost_usd,
|
||||
cost_status,
|
||||
cost_source,
|
||||
pricing_version,
|
||||
billing_provider,
|
||||
billing_base_url,
|
||||
billing_mode,
|
||||
model,
|
||||
session_id,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a session by ID."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]:
|
||||
@@ -266,11 +354,12 @@ class SessionDB:
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE id LIKE ? ESCAPE '\\' ORDER BY started_at DESC LIMIT 2",
|
||||
(f"{escaped}%",),
|
||||
)
|
||||
matches = [row["id"] for row in cursor.fetchall()]
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE id LIKE ? ESCAPE '\\' ORDER BY started_at DESC LIMIT 2",
|
||||
(f"{escaped}%",),
|
||||
)
|
||||
matches = [row["id"] for row in cursor.fetchall()]
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
return None
|
||||
@@ -331,38 +420,42 @@ class SessionDB:
|
||||
Empty/whitespace-only strings are normalized to None (clearing the title).
|
||||
"""
|
||||
title = self.sanitize_title(title)
|
||||
if title:
|
||||
# Check uniqueness (allow the same session to keep its own title)
|
||||
with self._lock:
|
||||
if title:
|
||||
# Check uniqueness (allow the same session to keep its own title)
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE title = ? AND id != ?",
|
||||
(title, session_id),
|
||||
)
|
||||
conflict = cursor.fetchone()
|
||||
if conflict:
|
||||
raise ValueError(
|
||||
f"Title '{title}' is already in use by session {conflict['id']}"
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE title = ? AND id != ?",
|
||||
"UPDATE sessions SET title = ? WHERE id = ?",
|
||||
(title, session_id),
|
||||
)
|
||||
conflict = cursor.fetchone()
|
||||
if conflict:
|
||||
raise ValueError(
|
||||
f"Title '{title}' is already in use by session {conflict['id']}"
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
"UPDATE sessions SET title = ? WHERE id = ?",
|
||||
(title, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
self._conn.commit()
|
||||
rowcount = cursor.rowcount
|
||||
return rowcount > 0
|
||||
|
||||
def get_session_title(self, session_id: str) -> Optional[str]:
|
||||
"""Get the title for a session, or None."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT title FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT title FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return row["title"] if row else None
|
||||
|
||||
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
|
||||
"""Look up a session by exact title. Returns session dict or None."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE title = ?", (title,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE title = ?", (title,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def resolve_session_by_title(self, title: str) -> Optional[str]:
|
||||
@@ -379,12 +472,13 @@ class SessionDB:
|
||||
# Also search for numbered variants: "title #2", "title #3", etc.
|
||||
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
|
||||
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id, title, started_at FROM sessions "
|
||||
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
|
||||
(f"{escaped} #%",),
|
||||
)
|
||||
numbered = cursor.fetchall()
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id, title, started_at FROM sessions "
|
||||
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
|
||||
(f"{escaped} #%",),
|
||||
)
|
||||
numbered = cursor.fetchall()
|
||||
|
||||
if numbered:
|
||||
# Return the most recent numbered variant
|
||||
@@ -409,11 +503,12 @@ class SessionDB:
|
||||
# Find all existing numbered variants
|
||||
# Escape SQL LIKE wildcards (%, _) in the base to prevent false matches
|
||||
escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
cursor = self._conn.execute(
|
||||
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
|
||||
(base, f"{escaped} #%"),
|
||||
)
|
||||
existing = [row["title"] for row in cursor.fetchall()]
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
|
||||
(base, f"{escaped} #%"),
|
||||
)
|
||||
existing = [row["title"] for row in cursor.fetchall()]
|
||||
|
||||
if not existing:
|
||||
return base # No conflict, use the base name as-is
|
||||
@@ -461,9 +556,11 @@ class SessionDB:
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
params = (source, limit, offset) if source else (limit, offset)
|
||||
cursor = self._conn.execute(query, params)
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(query, params)
|
||||
rows = cursor.fetchall()
|
||||
sessions = []
|
||||
for row in cursor.fetchall():
|
||||
for row in rows:
|
||||
s = dict(row)
|
||||
# Build the preview from the raw substring
|
||||
raw = s.pop("_preview_raw", "").strip()
|
||||
@@ -497,52 +594,54 @@ class SessionDB:
|
||||
Also increments the session's message_count (and tool_call_count
|
||||
if role is 'tool' or tool_calls is present).
|
||||
"""
|
||||
cursor = self._conn.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, tool_call_id,
|
||||
tool_calls, tool_name, timestamp, token_count, finish_reason)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
session_id,
|
||||
role,
|
||||
content,
|
||||
tool_call_id,
|
||||
json.dumps(tool_calls) if tool_calls else None,
|
||||
tool_name,
|
||||
time.time(),
|
||||
token_count,
|
||||
finish_reason,
|
||||
),
|
||||
)
|
||||
msg_id = cursor.lastrowid
|
||||
|
||||
# Update counters
|
||||
# Count actual tool calls from the tool_calls list (not from tool responses).
|
||||
# A single assistant message can contain multiple parallel tool calls.
|
||||
num_tool_calls = 0
|
||||
if tool_calls is not None:
|
||||
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
|
||||
if num_tool_calls > 0:
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET message_count = message_count + 1,
|
||||
tool_call_count = tool_call_count + ? WHERE id = ?""",
|
||||
(num_tool_calls, session_id),
|
||||
)
|
||||
else:
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
|
||||
(session_id,),
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, tool_call_id,
|
||||
tool_calls, tool_name, timestamp, token_count, finish_reason)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
session_id,
|
||||
role,
|
||||
content,
|
||||
tool_call_id,
|
||||
json.dumps(tool_calls) if tool_calls else None,
|
||||
tool_name,
|
||||
time.time(),
|
||||
token_count,
|
||||
finish_reason,
|
||||
),
|
||||
)
|
||||
msg_id = cursor.lastrowid
|
||||
|
||||
self._conn.commit()
|
||||
# Update counters
|
||||
# Count actual tool calls from the tool_calls list (not from tool responses).
|
||||
# A single assistant message can contain multiple parallel tool calls.
|
||||
num_tool_calls = 0
|
||||
if tool_calls is not None:
|
||||
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
|
||||
if num_tool_calls > 0:
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET message_count = message_count + 1,
|
||||
tool_call_count = tool_call_count + ? WHERE id = ?""",
|
||||
(num_tool_calls, session_id),
|
||||
)
|
||||
else:
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
return msg_id
|
||||
|
||||
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
"""Load all messages for a session, ordered by timestamp."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id",
|
||||
(session_id,),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id",
|
||||
(session_id,),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
result = []
|
||||
for row in rows:
|
||||
msg = dict(row)
|
||||
@@ -559,13 +658,15 @@ class SessionDB:
|
||||
Load messages in the OpenAI conversation format (role + content dicts).
|
||||
Used by the gateway to restore conversation history.
|
||||
"""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT role, content, tool_call_id, tool_calls, tool_name "
|
||||
"FROM messages WHERE session_id = ? ORDER BY timestamp, id",
|
||||
(session_id,),
|
||||
)
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT role, content, tool_call_id, tool_calls, tool_name "
|
||||
"FROM messages WHERE session_id = ? ORDER BY timestamp, id",
|
||||
(session_id,),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
messages = []
|
||||
for row in cursor.fetchall():
|
||||
for row in rows:
|
||||
msg = {"role": row["role"], "content": row["content"]}
|
||||
if row["tool_call_id"]:
|
||||
msg["tool_call_id"] = row["tool_call_id"]
|
||||
@@ -592,21 +693,45 @@ class SessionDB:
|
||||
``NOT``) have special meaning. Passing raw user input directly to
|
||||
MATCH can cause ``sqlite3.OperationalError``.
|
||||
|
||||
Strategy: strip characters that are only meaningful as FTS5 operators
|
||||
and would otherwise cause syntax errors. This preserves normal keyword
|
||||
search while preventing crashes on inputs like ``C++``, ``"unterminated``,
|
||||
or ``hello AND``.
|
||||
Strategy:
|
||||
- Preserve properly paired quoted phrases (``"exact phrase"``)
|
||||
- Strip unmatched FTS5-special characters that would cause errors
|
||||
- Wrap unquoted hyphenated terms in quotes so FTS5 matches them
|
||||
as exact phrases instead of splitting on the hyphen
|
||||
"""
|
||||
# Remove FTS5-special characters that are not useful in keyword search
|
||||
sanitized = re.sub(r'[+{}()"^]', " ", query)
|
||||
# Collapse repeated * (e.g. "***") into a single one, and remove
|
||||
# leading * (prefix-only matching requires at least one char before *)
|
||||
# Step 1: Extract balanced double-quoted phrases and protect them
|
||||
# from further processing via numbered placeholders.
|
||||
_quoted_parts: list = []
|
||||
|
||||
def _preserve_quoted(m: re.Match) -> str:
|
||||
_quoted_parts.append(m.group(0))
|
||||
return f"\x00Q{len(_quoted_parts) - 1}\x00"
|
||||
|
||||
sanitized = re.sub(r'"[^"]*"', _preserve_quoted, query)
|
||||
|
||||
# Step 2: Strip remaining (unmatched) FTS5-special characters
|
||||
sanitized = re.sub(r'[+{}()\"^]', " ", sanitized)
|
||||
|
||||
# Step 3: Collapse repeated * (e.g. "***") into a single one,
|
||||
# and remove leading * (prefix-only needs at least one char before *)
|
||||
sanitized = re.sub(r"\*+", "*", sanitized)
|
||||
sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized)
|
||||
# Remove dangling boolean operators at start/end that would cause
|
||||
# syntax errors (e.g. "hello AND" or "OR world")
|
||||
|
||||
# Step 4: Remove dangling boolean operators at start/end that would
|
||||
# cause syntax errors (e.g. "hello AND" or "OR world")
|
||||
sanitized = re.sub(r"(?i)^(AND|OR|NOT)\b\s*", "", sanitized.strip())
|
||||
sanitized = re.sub(r"(?i)\s+(AND|OR|NOT)\s*$", "", sanitized.strip())
|
||||
|
||||
# Step 5: Wrap unquoted hyphenated terms (e.g. ``chat-send``) in
|
||||
# double quotes. FTS5's tokenizer splits on hyphens, turning
|
||||
# ``chat-send`` into ``chat AND send``. Quoting preserves the
|
||||
# intended phrase match.
|
||||
sanitized = re.sub(r"\b(\w+(?:-\w+)+)\b", r'"\1"', sanitized)
|
||||
|
||||
# Step 6: Restore preserved quoted phrases
|
||||
for i, quoted in enumerate(_quoted_parts):
|
||||
sanitized = sanitized.replace(f"\x00Q{i}\x00", quoted)
|
||||
|
||||
return sanitized.strip()
|
||||
|
||||
def search_messages(
|
||||
@@ -636,16 +761,14 @@ class SessionDB:
|
||||
if not query:
|
||||
return []
|
||||
|
||||
if source_filter is None:
|
||||
source_filter = ["cli", "telegram", "discord", "whatsapp", "slack"]
|
||||
|
||||
# Build WHERE clauses dynamically
|
||||
where_clauses = ["messages_fts MATCH ?"]
|
||||
params: list = [query]
|
||||
|
||||
source_placeholders = ",".join("?" for _ in source_filter)
|
||||
where_clauses.append(f"s.source IN ({source_placeholders})")
|
||||
params.extend(source_filter)
|
||||
if source_filter is not None:
|
||||
source_placeholders = ",".join("?" for _ in source_filter)
|
||||
where_clauses.append(f"s.source IN ({source_placeholders})")
|
||||
params.extend(source_filter)
|
||||
|
||||
if role_filter:
|
||||
role_placeholders = ",".join("?" for _ in role_filter)
|
||||
@@ -675,31 +798,33 @@ class SessionDB:
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
try:
|
||||
cursor = self._conn.execute(sql, params)
|
||||
except sqlite3.OperationalError:
|
||||
# FTS5 query syntax error despite sanitization — return empty
|
||||
return []
|
||||
matches = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
# Add surrounding context (1 message before + after each match)
|
||||
for match in matches:
|
||||
with self._lock:
|
||||
try:
|
||||
ctx_cursor = self._conn.execute(
|
||||
"""SELECT role, content FROM messages
|
||||
WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1
|
||||
ORDER BY id""",
|
||||
(match["session_id"], match["id"], match["id"]),
|
||||
)
|
||||
context_msgs = [
|
||||
{"role": r["role"], "content": (r["content"] or "")[:200]}
|
||||
for r in ctx_cursor.fetchall()
|
||||
]
|
||||
match["context"] = context_msgs
|
||||
except Exception:
|
||||
match["context"] = []
|
||||
cursor = self._conn.execute(sql, params)
|
||||
except sqlite3.OperationalError:
|
||||
# FTS5 query syntax error despite sanitization — return empty
|
||||
return []
|
||||
matches = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
# Remove full content from result (snippet is enough, saves tokens)
|
||||
# Add surrounding context (1 message before + after each match)
|
||||
for match in matches:
|
||||
try:
|
||||
ctx_cursor = self._conn.execute(
|
||||
"""SELECT role, content FROM messages
|
||||
WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1
|
||||
ORDER BY id""",
|
||||
(match["session_id"], match["id"], match["id"]),
|
||||
)
|
||||
context_msgs = [
|
||||
{"role": r["role"], "content": (r["content"] or "")[:200]}
|
||||
for r in ctx_cursor.fetchall()
|
||||
]
|
||||
match["context"] = context_msgs
|
||||
except Exception:
|
||||
match["context"] = []
|
||||
|
||||
# Remove full content from result (snippet is enough, saves tokens)
|
||||
for match in matches:
|
||||
match.pop("content", None)
|
||||
|
||||
return matches
|
||||
@@ -711,17 +836,18 @@ class SessionDB:
|
||||
offset: int = 0,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List sessions, optionally filtered by source."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
(source, limit, offset),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
with self._lock:
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
(source, limit, offset),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
# =========================================================================
|
||||
# Utility
|
||||
@@ -729,23 +855,25 @@ class SessionDB:
|
||||
|
||||
def session_count(self, source: str = None) -> int:
|
||||
"""Count sessions, optionally filtered by source."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE source = ?", (source,)
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions")
|
||||
return cursor.fetchone()[0]
|
||||
with self._lock:
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE source = ?", (source,)
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
def message_count(self, session_id: str = None) -> int:
|
||||
"""Count messages, optionally for a specific session."""
|
||||
if session_id:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM messages")
|
||||
return cursor.fetchone()[0]
|
||||
with self._lock:
|
||||
if session_id:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM messages")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
# =========================================================================
|
||||
# Export and cleanup
|
||||
@@ -773,26 +901,28 @@ class SessionDB:
|
||||
|
||||
def clear_messages(self, session_id: str) -> None:
|
||||
"""Delete all messages for a session and reset its counters."""
|
||||
self._conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete a session and all its messages. Returns True if found."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
if cursor.fetchone()[0] == 0:
|
||||
return False
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
self._conn.commit()
|
||||
return True
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
if cursor.fetchone()[0] == 0:
|
||||
return False
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
self._conn.commit()
|
||||
return True
|
||||
|
||||
def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
|
||||
"""
|
||||
@@ -802,22 +932,23 @@ class SessionDB:
|
||||
import time as _time
|
||||
cutoff = _time.time() - (older_than_days * 86400)
|
||||
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"""SELECT id FROM sessions
|
||||
WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
|
||||
(cutoff, source),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
|
||||
(cutoff,),
|
||||
)
|
||||
session_ids = [row["id"] for row in cursor.fetchall()]
|
||||
with self._lock:
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"""SELECT id FROM sessions
|
||||
WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
|
||||
(cutoff, source),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
|
||||
(cutoff,),
|
||||
)
|
||||
session_ids = [row["id"] for row in cursor.fetchall()]
|
||||
|
||||
for sid in session_ids:
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
|
||||
for sid in session_ids:
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
|
||||
|
||||
self._conn.commit()
|
||||
self._conn.commit()
|
||||
return len(session_ids)
|
||||
|
||||
+31
-16
@@ -10,22 +10,30 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
GLOBAL_CONFIG_PATH = Path.home() / ".honcho" / "config.json"
|
||||
from honcho_integration.client import resolve_config_path, GLOBAL_CONFIG_PATH
|
||||
|
||||
HOST = "hermes"
|
||||
|
||||
|
||||
def _config_path() -> Path:
|
||||
"""Return the active Honcho config path (instance-local or global)."""
|
||||
return resolve_config_path()
|
||||
|
||||
|
||||
def _read_config() -> dict:
|
||||
if GLOBAL_CONFIG_PATH.exists():
|
||||
path = _config_path()
|
||||
if path.exists():
|
||||
try:
|
||||
return json.loads(GLOBAL_CONFIG_PATH.read_text(encoding="utf-8"))
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _write_config(cfg: dict) -> None:
|
||||
GLOBAL_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
GLOBAL_CONFIG_PATH.write_text(
|
||||
def _write_config(cfg: dict, path: Path | None = None) -> None:
|
||||
path = path or _config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(
|
||||
json.dumps(cfg, indent=2, ensure_ascii=False) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
@@ -87,9 +95,14 @@ def cmd_setup(args) -> None:
|
||||
"""Interactive Honcho setup wizard."""
|
||||
cfg = _read_config()
|
||||
|
||||
active_path = _config_path()
|
||||
print("\nHoncho memory setup\n" + "─" * 40)
|
||||
print(" Honcho gives Hermes persistent cross-session memory.")
|
||||
print(" Config is shared with other hosts at ~/.honcho/config.json\n")
|
||||
if active_path != GLOBAL_CONFIG_PATH:
|
||||
print(f" Instance config: {active_path}")
|
||||
else:
|
||||
print(" Config is shared with other hosts at ~/.honcho/config.json")
|
||||
print()
|
||||
|
||||
if not _ensure_sdk_installed():
|
||||
return
|
||||
@@ -162,10 +175,10 @@ def cmd_setup(args) -> None:
|
||||
hermes_host["recallMode"] = new_recall
|
||||
|
||||
# Session strategy
|
||||
current_strat = hermes_host.get("sessionStrategy") or cfg.get("sessionStrategy", "per-session")
|
||||
current_strat = hermes_host.get("sessionStrategy") or cfg.get("sessionStrategy", "per-directory")
|
||||
print(f"\n Session strategy options:")
|
||||
print(" per-session — new Honcho session each run, named by Hermes session ID (default)")
|
||||
print(" per-directory — one session per working directory")
|
||||
print(" per-directory — one session per working directory (default)")
|
||||
print(" per-session — new Honcho session each run, named by Hermes session ID")
|
||||
print(" per-repo — one session per git repository (uses repo root name)")
|
||||
print(" global — single session across all directories")
|
||||
new_strat = _prompt("Session strategy", default=current_strat)
|
||||
@@ -176,7 +189,7 @@ def cmd_setup(args) -> None:
|
||||
hermes_host.setdefault("saveMessages", True)
|
||||
|
||||
_write_config(cfg)
|
||||
print(f"\n Config written to {GLOBAL_CONFIG_PATH}")
|
||||
print(f"\n Config written to {active_path}")
|
||||
|
||||
# Test connection
|
||||
print(" Testing connection... ", end="", flush=True)
|
||||
@@ -223,8 +236,10 @@ def cmd_status(args) -> None:
|
||||
|
||||
cfg = _read_config()
|
||||
|
||||
active_path = _config_path()
|
||||
|
||||
if not cfg:
|
||||
print(" No Honcho config found at ~/.honcho/config.json")
|
||||
print(f" No Honcho config found at {active_path}")
|
||||
print(" Run 'hermes honcho setup' to configure.\n")
|
||||
return
|
||||
|
||||
@@ -243,7 +258,7 @@ def cmd_status(args) -> None:
|
||||
print(f" API key: {masked}")
|
||||
print(f" Workspace: {hcfg.workspace_id}")
|
||||
print(f" Host: {hcfg.host}")
|
||||
print(f" Config path: {GLOBAL_CONFIG_PATH}")
|
||||
print(f" Config path: {active_path}")
|
||||
print(f" AI peer: {hcfg.ai_peer}")
|
||||
print(f" User peer: {hcfg.peer_name or 'not set'}")
|
||||
print(f" Session key: {hcfg.resolve_session_name()}")
|
||||
@@ -275,7 +290,7 @@ def cmd_sessions(args) -> None:
|
||||
if not sessions:
|
||||
print(" No session mappings configured.\n")
|
||||
print(" Add one with: hermes honcho map <session-name>")
|
||||
print(" Or edit ~/.honcho/config.json directly.\n")
|
||||
print(f" Or edit {_config_path()} directly.\n")
|
||||
return
|
||||
|
||||
cwd = os.getcwd()
|
||||
@@ -361,7 +376,7 @@ def cmd_peer(args) -> None:
|
||||
|
||||
if changed:
|
||||
_write_config(cfg)
|
||||
print(f" Saved to {GLOBAL_CONFIG_PATH}\n")
|
||||
print(f" Saved to {_config_path()}\n")
|
||||
|
||||
|
||||
def cmd_mode(args) -> None:
|
||||
@@ -434,7 +449,7 @@ def cmd_tokens(args) -> None:
|
||||
|
||||
if changed:
|
||||
_write_config(cfg)
|
||||
print(f" Saved to {GLOBAL_CONFIG_PATH}\n")
|
||||
print(f" Saved to {_config_path()}\n")
|
||||
|
||||
|
||||
def cmd_identity(args) -> None:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""Honcho client initialization and configuration.
|
||||
|
||||
Reads the global ~/.honcho/config.json when available, falling back
|
||||
to environment variables.
|
||||
Resolution order for config file:
|
||||
1. $HERMES_HOME/honcho.json (instance-local, enables isolated Hermes instances)
|
||||
2. ~/.honcho/config.json (global, shared across all Honcho-enabled apps)
|
||||
3. Environment variables (HONCHO_API_KEY, HONCHO_ENVIRONMENT)
|
||||
|
||||
Resolution order for host-specific settings:
|
||||
1. Explicit host block fields (always win)
|
||||
@@ -27,6 +29,24 @@ GLOBAL_CONFIG_PATH = Path.home() / ".honcho" / "config.json"
|
||||
HOST = "hermes"
|
||||
|
||||
|
||||
def _get_hermes_home() -> Path:
|
||||
"""Get HERMES_HOME without importing hermes_cli (avoids circular deps)."""
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
|
||||
def resolve_config_path() -> Path:
|
||||
"""Return the active Honcho config path.
|
||||
|
||||
Checks $HERMES_HOME/honcho.json first (instance-local), then falls back
|
||||
to ~/.honcho/config.json (global). Returns the global path if neither
|
||||
exists (for first-time setup writes).
|
||||
"""
|
||||
local_path = _get_hermes_home() / "honcho.json"
|
||||
if local_path.exists():
|
||||
return local_path
|
||||
return GLOBAL_CONFIG_PATH
|
||||
|
||||
|
||||
_RECALL_MODE_ALIASES = {"auto": "hybrid"}
|
||||
_VALID_RECALL_MODES = {"hybrid", "context", "tools"}
|
||||
|
||||
@@ -69,6 +89,8 @@ class HonchoClientConfig:
|
||||
workspace_id: str = "hermes"
|
||||
api_key: str | None = None
|
||||
environment: str = "production"
|
||||
# Optional base URL for self-hosted Honcho (overrides environment mapping)
|
||||
base_url: str | None = None
|
||||
# Identity
|
||||
peer_name: str | None = None
|
||||
ai_peer: str = "hermes"
|
||||
@@ -105,20 +127,27 @@ class HonchoClientConfig:
|
||||
# "tools" — Honcho tools only, no auto-injected context
|
||||
recall_mode: str = "hybrid"
|
||||
# Session resolution
|
||||
session_strategy: str = "per-session"
|
||||
session_strategy: str = "per-directory"
|
||||
session_peer_prefix: bool = False
|
||||
sessions: dict[str, str] = field(default_factory=dict)
|
||||
# Raw global config for anything else consumers need
|
||||
raw: dict[str, Any] = field(default_factory=dict)
|
||||
# True when Honcho was explicitly configured for this host (hosts.hermes
|
||||
# block exists or enabled was set explicitly), vs auto-enabled from a
|
||||
# stray HONCHO_API_KEY env var.
|
||||
explicitly_configured: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_env(cls, workspace_id: str = "hermes") -> HonchoClientConfig:
|
||||
"""Create config from environment variables (fallback)."""
|
||||
api_key = os.environ.get("HONCHO_API_KEY")
|
||||
base_url = os.environ.get("HONCHO_BASE_URL", "").strip() or None
|
||||
return cls(
|
||||
workspace_id=workspace_id,
|
||||
api_key=os.environ.get("HONCHO_API_KEY"),
|
||||
api_key=api_key,
|
||||
environment=os.environ.get("HONCHO_ENVIRONMENT", "production"),
|
||||
enabled=True,
|
||||
base_url=base_url,
|
||||
enabled=bool(api_key or base_url),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -127,11 +156,11 @@ class HonchoClientConfig:
|
||||
host: str = HOST,
|
||||
config_path: Path | None = None,
|
||||
) -> HonchoClientConfig:
|
||||
"""Create config from ~/.honcho/config.json.
|
||||
"""Create config from the resolved Honcho config path.
|
||||
|
||||
Falls back to environment variables if the file doesn't exist.
|
||||
Resolution: $HERMES_HOME/honcho.json -> ~/.honcho/config.json -> env vars.
|
||||
"""
|
||||
path = config_path or GLOBAL_CONFIG_PATH
|
||||
path = config_path or resolve_config_path()
|
||||
if not path.exists():
|
||||
logger.debug("No global Honcho config at %s, falling back to env", path)
|
||||
return cls.from_env()
|
||||
@@ -143,6 +172,9 @@ class HonchoClientConfig:
|
||||
return cls.from_env()
|
||||
|
||||
host_block = (raw.get("hosts") or {}).get(host, {})
|
||||
# A hosts.hermes block or explicit enabled flag means the user
|
||||
# intentionally configured Honcho for this host.
|
||||
_explicitly_configured = bool(host_block) or raw.get("enabled") is True
|
||||
|
||||
# Explicit host block fields win, then flat/global, then defaults
|
||||
workspace = (
|
||||
@@ -168,8 +200,14 @@ class HonchoClientConfig:
|
||||
or raw.get("environment", "production")
|
||||
)
|
||||
|
||||
# Auto-enable when API key is present (unless explicitly disabled)
|
||||
# Host-level enabled wins, then root-level, then auto-enable if key exists.
|
||||
base_url = (
|
||||
raw.get("baseUrl")
|
||||
or os.environ.get("HONCHO_BASE_URL", "").strip()
|
||||
or None
|
||||
)
|
||||
|
||||
# Auto-enable when API key or base_url is present (unless explicitly disabled)
|
||||
# Host-level enabled wins, then root-level, then auto-enable if key/url exists.
|
||||
host_enabled = host_block.get("enabled")
|
||||
root_enabled = raw.get("enabled")
|
||||
if host_enabled is not None:
|
||||
@@ -177,8 +215,8 @@ class HonchoClientConfig:
|
||||
elif root_enabled is not None:
|
||||
enabled = root_enabled
|
||||
else:
|
||||
# Not explicitly set anywhere -> auto-enable if API key exists
|
||||
enabled = bool(api_key)
|
||||
# Not explicitly set anywhere -> auto-enable if API key or base_url exists
|
||||
enabled = bool(api_key or base_url)
|
||||
|
||||
# write_frequency: accept int or string
|
||||
raw_wf = (
|
||||
@@ -198,7 +236,7 @@ class HonchoClientConfig:
|
||||
# sessionStrategy / sessionPeerPrefix: host first, root fallback
|
||||
session_strategy = (
|
||||
host_block.get("sessionStrategy")
|
||||
or raw.get("sessionStrategy", "per-session")
|
||||
or raw.get("sessionStrategy", "per-directory")
|
||||
)
|
||||
host_prefix = host_block.get("sessionPeerPrefix")
|
||||
session_peer_prefix = (
|
||||
@@ -211,6 +249,7 @@ class HonchoClientConfig:
|
||||
workspace_id=workspace,
|
||||
api_key=api_key,
|
||||
environment=environment,
|
||||
base_url=base_url,
|
||||
peer_name=host_block.get("peerName") or raw.get("peerName"),
|
||||
ai_peer=ai_peer,
|
||||
linked_hosts=linked_hosts,
|
||||
@@ -241,6 +280,7 @@ class HonchoClientConfig:
|
||||
session_peer_prefix=session_peer_prefix,
|
||||
sessions=raw.get("sessions", {}),
|
||||
raw=raw,
|
||||
explicitly_configured=_explicitly_configured,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -306,7 +346,7 @@ class HonchoClientConfig:
|
||||
return f"{self.peer_name}-{base}"
|
||||
return base
|
||||
|
||||
# per-directory: one Honcho session per working directory
|
||||
# per-directory: one Honcho session per working directory (default)
|
||||
if self.session_strategy in ("per-directory", "per-session"):
|
||||
base = Path(cwd).name
|
||||
if self.session_peer_prefix and self.peer_name:
|
||||
@@ -345,11 +385,12 @@ def get_honcho_client(config: HonchoClientConfig | None = None) -> Honcho:
|
||||
if config is None:
|
||||
config = HonchoClientConfig.from_global_config()
|
||||
|
||||
if not config.api_key:
|
||||
if not config.api_key and not config.base_url:
|
||||
raise ValueError(
|
||||
"Honcho API key not found. "
|
||||
"Get your API key at https://app.honcho.dev, "
|
||||
"then run 'hermes honcho setup' or set HONCHO_API_KEY."
|
||||
"then run 'hermes honcho setup' or set HONCHO_API_KEY. "
|
||||
"For local instances, set HONCHO_BASE_URL instead."
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -360,13 +401,34 @@ def get_honcho_client(config: HonchoClientConfig | None = None) -> Honcho:
|
||||
"Install it with: pip install honcho-ai"
|
||||
)
|
||||
|
||||
logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id)
|
||||
# Allow config.yaml honcho.base_url to override the SDK's environment
|
||||
# mapping, enabling remote self-hosted Honcho deployments without
|
||||
# requiring the server to live on localhost.
|
||||
resolved_base_url = config.base_url
|
||||
if not resolved_base_url:
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
hermes_cfg = load_config()
|
||||
honcho_cfg = hermes_cfg.get("honcho", {})
|
||||
if isinstance(honcho_cfg, dict):
|
||||
resolved_base_url = honcho_cfg.get("base_url", "").strip() or None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_honcho_client = Honcho(
|
||||
workspace_id=config.workspace_id,
|
||||
api_key=config.api_key,
|
||||
environment=config.environment,
|
||||
)
|
||||
if resolved_base_url:
|
||||
logger.info("Initializing Honcho client (base_url: %s, workspace: %s)", resolved_base_url, config.workspace_id)
|
||||
else:
|
||||
logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id)
|
||||
|
||||
kwargs: dict = {
|
||||
"workspace_id": config.workspace_id,
|
||||
"api_key": config.api_key,
|
||||
"environment": config.environment,
|
||||
}
|
||||
if resolved_base_url:
|
||||
kwargs["base_url"] = resolved_base_url
|
||||
|
||||
_honcho_client = Honcho(**kwargs)
|
||||
|
||||
return _honcho_client
|
||||
|
||||
|
||||
@@ -927,6 +927,11 @@ class HonchoSessionManager:
|
||||
return False
|
||||
|
||||
assistant_peer = self._get_or_create_peer(session.assistant_peer_id)
|
||||
honcho_session = self._sessions_cache.get(session.honcho_session_id)
|
||||
if not honcho_session:
|
||||
logger.warning("No Honcho session cached for '%s', skipping AI seed", session_key)
|
||||
return False
|
||||
|
||||
try:
|
||||
wrapped = (
|
||||
f"<ai_identity_seed>\n"
|
||||
@@ -935,7 +940,7 @@ class HonchoSessionManager:
|
||||
f"{content.strip()}\n"
|
||||
f"</ai_identity_seed>"
|
||||
)
|
||||
assistant_peer.add_message("assistant", wrapped)
|
||||
honcho_session.add_messages([assistant_peer.message(wrapped)])
|
||||
logger.info("Seeded AI identity from '%s' into %s", source, session_key)
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
@@ -339,6 +339,7 @@ class MiniSWERunner:
|
||||
|
||||
# Add tool calls in XML format
|
||||
for tool_call in msg["tool_calls"]:
|
||||
if not tool_call or not isinstance(tool_call, dict): continue
|
||||
try:
|
||||
arguments = json.loads(tool_call["function"]["arguments"]) \
|
||||
if isinstance(tool_call["function"]["arguments"], str) \
|
||||
|
||||
+138
-15
@@ -22,8 +22,8 @@ Public API (signatures preserved from the original 2,400-line version):
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
from tools.registry import registry
|
||||
@@ -36,6 +36,48 @@ logger = logging.getLogger(__name__)
|
||||
# Async Bridging (single source of truth -- used by registry.dispatch too)
|
||||
# =============================================================================
|
||||
|
||||
_tool_loop = None # persistent loop for the main (CLI) thread
|
||||
_tool_loop_lock = threading.Lock()
|
||||
_worker_thread_local = threading.local() # per-worker-thread persistent loops
|
||||
|
||||
|
||||
def _get_tool_loop():
|
||||
"""Return a long-lived event loop for running async tool handlers.
|
||||
|
||||
Using a persistent loop (instead of asyncio.run() which creates and
|
||||
*closes* a fresh loop every time) prevents "Event loop is closed"
|
||||
errors that occur when cached httpx/AsyncOpenAI clients attempt to
|
||||
close their transport on a dead loop during garbage collection.
|
||||
"""
|
||||
global _tool_loop
|
||||
with _tool_loop_lock:
|
||||
if _tool_loop is None or _tool_loop.is_closed():
|
||||
_tool_loop = asyncio.new_event_loop()
|
||||
return _tool_loop
|
||||
|
||||
|
||||
def _get_worker_loop():
|
||||
"""Return a persistent event loop for the current worker thread.
|
||||
|
||||
Each worker thread (e.g., delegate_task's ThreadPoolExecutor threads)
|
||||
gets its own long-lived loop stored in thread-local storage. This
|
||||
prevents the "Event loop is closed" errors that occurred when
|
||||
asyncio.run() was used per-call: asyncio.run() creates a loop, runs
|
||||
the coroutine, then *closes* the loop — but cached httpx/AsyncOpenAI
|
||||
clients remain bound to that now-dead loop and raise RuntimeError
|
||||
during garbage collection or subsequent use.
|
||||
|
||||
By keeping the loop alive for the thread's lifetime, cached clients
|
||||
stay valid and their cleanup runs on a live loop.
|
||||
"""
|
||||
loop = getattr(_worker_thread_local, 'loop', None)
|
||||
if loop is None or loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
_worker_thread_local.loop = loop
|
||||
return loop
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine from a sync context.
|
||||
|
||||
@@ -44,6 +86,15 @@ def _run_async(coro):
|
||||
disposable thread so asyncio.run() can create its own loop without
|
||||
conflicting.
|
||||
|
||||
For the common CLI path (no running loop), we use a persistent event
|
||||
loop so that cached async clients (httpx / AsyncOpenAI) remain bound
|
||||
to a live loop and don't trigger "Event loop is closed" on GC.
|
||||
|
||||
When called from a worker thread (parallel tool execution), we use a
|
||||
per-thread persistent loop to avoid both contention with the main
|
||||
thread's shared loop AND the "Event loop is closed" errors caused by
|
||||
asyncio.run()'s create-and-destroy lifecycle.
|
||||
|
||||
This is the single source of truth for sync->async bridging in tool
|
||||
handlers. The RL paths (agent_loop.py, tool_context.py) also provide
|
||||
outer thread-pool wrapping as defense-in-depth, but each handler is
|
||||
@@ -55,11 +106,23 @@ def _run_async(coro):
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
# Inside an async context (gateway, RL env) — run in a fresh thread.
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, coro)
|
||||
return future.result(timeout=300)
|
||||
return asyncio.run(coro)
|
||||
|
||||
# If we're on a worker thread (e.g., parallel tool execution in
|
||||
# delegate_task), use a per-thread persistent loop. This avoids
|
||||
# contention with the main thread's shared loop while keeping cached
|
||||
# httpx/AsyncOpenAI clients bound to a live loop for the thread's
|
||||
# lifetime — preventing "Event loop is closed" on GC cleanup.
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
worker_loop = _get_worker_loop()
|
||||
return worker_loop.run_until_complete(coro)
|
||||
|
||||
tool_loop = _get_tool_loop()
|
||||
return tool_loop.run_until_complete(coro)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -101,7 +164,7 @@ def _discover_tools():
|
||||
try:
|
||||
importlib.import_module(mod_name)
|
||||
except Exception as e:
|
||||
logger.debug("Could not import %s: %s", mod_name, e)
|
||||
logger.warning("Could not import tool module %s: %s", mod_name, e)
|
||||
|
||||
|
||||
_discover_tools()
|
||||
@@ -113,6 +176,13 @@ try:
|
||||
except Exception as e:
|
||||
logger.debug("MCP tool discovery failed: %s", e)
|
||||
|
||||
# Plugin tool discovery (user/project/pip plugins)
|
||||
try:
|
||||
from hermes_cli.plugins import discover_plugins
|
||||
discover_plugins()
|
||||
except Exception as e:
|
||||
logger.debug("Plugin discovery failed: %s", e)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Backward-compat constants (built once after discovery)
|
||||
@@ -142,7 +212,7 @@ _LEGACY_TOOLSET_MAP = {
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
"browser_vision", "browser_console"
|
||||
],
|
||||
"cronjob_tools": ["cronjob"],
|
||||
"rl_tools": [
|
||||
@@ -222,21 +292,54 @@ def get_tool_definitions(
|
||||
for ts_name in get_all_toolsets():
|
||||
tools_to_include.update(resolve_toolset(ts_name))
|
||||
|
||||
# Plugin-registered tools are now resolved through the normal toolset
|
||||
# path — validate_toolset() / resolve_toolset() / get_all_toolsets()
|
||||
# all check the tool registry for plugin-provided toolsets. No bypass
|
||||
# needed; plugins respect enabled_toolsets / disabled_toolsets like any
|
||||
# other toolset.
|
||||
|
||||
# Ask the registry for schemas (only returns tools whose check_fn passes)
|
||||
filtered_tools = registry.get_definitions(tools_to_include, quiet=quiet_mode)
|
||||
|
||||
# The set of tool names that actually passed check_fn filtering.
|
||||
# Use this (not tools_to_include) for any downstream schema that references
|
||||
# other tools by name — otherwise the model sees tools mentioned in
|
||||
# descriptions that don't actually exist, and hallucinates calls to them.
|
||||
available_tool_names = {t["function"]["name"] for t in filtered_tools}
|
||||
|
||||
# Rebuild execute_code schema to only list sandbox tools that are actually
|
||||
# enabled. Without this, the model sees "web_search is available in
|
||||
# execute_code" even when the user disabled the web toolset (#560-discord).
|
||||
if "execute_code" in tools_to_include:
|
||||
# available. Without this, the model sees "web_search is available in
|
||||
# execute_code" even when the API key isn't configured or the toolset is
|
||||
# disabled (#560-discord).
|
||||
if "execute_code" in available_tool_names:
|
||||
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
|
||||
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
|
||||
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & available_tool_names
|
||||
dynamic_schema = build_execute_code_schema(sandbox_enabled)
|
||||
for i, td in enumerate(filtered_tools):
|
||||
if td.get("function", {}).get("name") == "execute_code":
|
||||
filtered_tools[i] = {"type": "function", "function": dynamic_schema}
|
||||
break
|
||||
|
||||
# Strip web tool cross-references from browser_navigate description when
|
||||
# web_search / web_extract are not available. The static schema says
|
||||
# "prefer web_search or web_extract" which causes the model to hallucinate
|
||||
# those tools when they're missing.
|
||||
if "browser_navigate" in available_tool_names:
|
||||
web_tools_available = {"web_search", "web_extract"} & available_tool_names
|
||||
if not web_tools_available:
|
||||
for i, td in enumerate(filtered_tools):
|
||||
if td.get("function", {}).get("name") == "browser_navigate":
|
||||
desc = td["function"].get("description", "")
|
||||
desc = desc.replace(
|
||||
" For simple information retrieval, prefer web_search or web_extract (faster, cheaper).",
|
||||
"",
|
||||
)
|
||||
filtered_tools[i] = {
|
||||
"type": "function",
|
||||
"function": {**td["function"], "description": desc},
|
||||
}
|
||||
break
|
||||
|
||||
if not quiet_mode:
|
||||
if filtered_tools:
|
||||
tool_names = [t["function"]["name"] for t in filtered_tools]
|
||||
@@ -259,6 +362,7 @@ def get_tool_definitions(
|
||||
# The registry still holds their schemas; dispatch just returns a stub error
|
||||
# so if something slips through, the LLM sees a sensible message.
|
||||
_AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
|
||||
_READ_SEARCH_TOOLS = {"read_file", "search_files"}
|
||||
|
||||
|
||||
def handle_function_call(
|
||||
@@ -267,6 +371,8 @@ def handle_function_call(
|
||||
task_id: Optional[str] = None,
|
||||
user_task: Optional[str] = None,
|
||||
enabled_tools: Optional[List[str]] = None,
|
||||
honcho_manager: Optional[Any] = None,
|
||||
honcho_session_key: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Main function call dispatcher that routes calls to the tool registry.
|
||||
@@ -286,7 +392,6 @@ def handle_function_call(
|
||||
"""
|
||||
# Notify the read-loop tracker when a non-read/search tool runs,
|
||||
# so the *consecutive* counter resets (reads after other work are fine).
|
||||
_READ_SEARCH_TOOLS = {"read_file", "search_files"}
|
||||
if function_name not in _READ_SEARCH_TOOLS:
|
||||
try:
|
||||
from tools.file_tools import notify_other_tool_call
|
||||
@@ -298,21 +403,39 @@ def handle_function_call(
|
||||
if function_name in _AGENT_LOOP_TOOLS:
|
||||
return json.dumps({"error": f"{function_name} must be handled by the agent loop"})
|
||||
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook
|
||||
invoke_hook("pre_tool_call", tool_name=function_name, args=function_args, task_id=task_id or "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if function_name == "execute_code":
|
||||
# Prefer the caller-provided list so subagents can't overwrite
|
||||
# the parent's tool set via the process-global.
|
||||
sandbox_enabled = enabled_tools if enabled_tools is not None else _last_resolved_tool_names
|
||||
return registry.dispatch(
|
||||
result = registry.dispatch(
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
enabled_tools=sandbox_enabled,
|
||||
honcho_manager=honcho_manager,
|
||||
honcho_session_key=honcho_session_key,
|
||||
)
|
||||
else:
|
||||
result = registry.dispatch(
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
user_task=user_task,
|
||||
honcho_manager=honcho_manager,
|
||||
honcho_session_key=honcho_session_key,
|
||||
)
|
||||
|
||||
return registry.dispatch(
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
user_task=user_task,
|
||||
)
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook
|
||||
invoke_hook("post_tool_call", tool_name=function_name, args=function_args, result=result, task_id=task_id or "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing {function_name}: {str(e)}"
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
---
|
||||
name: base
|
||||
description: Query Base (Ethereum L2) blockchain data with USD pricing — wallet balances, token info, transaction details, gas analysis, contract inspection, whale detection, and live network stats. Uses Base RPC + CoinGecko. No API key required.
|
||||
version: 0.1.0
|
||||
author: youssefea
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Base, Blockchain, Crypto, Web3, RPC, DeFi, EVM, L2, Ethereum]
|
||||
related_skills: []
|
||||
---
|
||||
|
||||
# Base Blockchain Skill
|
||||
|
||||
Query Base (Ethereum L2) on-chain data enriched with USD pricing via CoinGecko.
|
||||
8 commands: wallet portfolio, token info, transactions, gas analysis,
|
||||
contract inspection, whale detection, network stats, and price lookup.
|
||||
|
||||
No API key needed. Uses only Python standard library (urllib, json, argparse).
|
||||
|
||||
---
|
||||
|
||||
## When to Use
|
||||
|
||||
- User asks for a Base wallet balance, token holdings, or portfolio value
|
||||
- User wants to inspect a specific transaction by hash
|
||||
- User wants ERC-20 token metadata, price, supply, or market cap
|
||||
- User wants to understand Base gas costs and L1 data fees
|
||||
- User wants to inspect a contract (ERC type detection, proxy resolution)
|
||||
- User wants to find large ETH transfers (whale detection)
|
||||
- User wants Base network health, gas price, or ETH price
|
||||
- User asks "what's the price of USDC/AERO/DEGEN/ETH?"
|
||||
|
||||
---
|
||||
|
||||
## Prerequisites
|
||||
|
||||
The helper script uses only Python standard library (urllib, json, argparse).
|
||||
No external packages required.
|
||||
|
||||
Pricing data comes from CoinGecko's free API (no key needed, rate-limited
|
||||
to ~10-30 requests/minute). For faster lookups, use `--no-prices` flag.
|
||||
|
||||
---
|
||||
|
||||
## Quick Reference
|
||||
|
||||
RPC endpoint (default): https://mainnet.base.org
|
||||
Override: export BASE_RPC_URL=https://your-private-rpc.com
|
||||
|
||||
Helper script path: ~/.hermes/skills/blockchain/base/scripts/base_client.py
|
||||
|
||||
```
|
||||
python3 base_client.py wallet <address> [--limit N] [--all] [--no-prices]
|
||||
python3 base_client.py tx <hash>
|
||||
python3 base_client.py token <contract_address>
|
||||
python3 base_client.py gas
|
||||
python3 base_client.py contract <address>
|
||||
python3 base_client.py whales [--min-eth N]
|
||||
python3 base_client.py stats
|
||||
python3 base_client.py price <contract_address_or_symbol>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Procedure
|
||||
|
||||
### 0. Setup Check
|
||||
|
||||
```bash
|
||||
python3 --version
|
||||
|
||||
# Optional: set a private RPC for better rate limits
|
||||
export BASE_RPC_URL="https://mainnet.base.org"
|
||||
|
||||
# Confirm connectivity
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py stats
|
||||
```
|
||||
|
||||
### 1. Wallet Portfolio
|
||||
|
||||
Get ETH balance and ERC-20 token holdings with USD values.
|
||||
Checks ~15 well-known Base tokens (USDC, WETH, AERO, DEGEN, etc.)
|
||||
via on-chain `balanceOf` calls. Tokens sorted by value, dust filtered.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py \
|
||||
wallet 0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045
|
||||
```
|
||||
|
||||
Flags:
|
||||
- `--limit N` — show top N tokens (default: 20)
|
||||
- `--all` — show all tokens, no dust filter, no limit
|
||||
- `--no-prices` — skip CoinGecko price lookups (faster, RPC-only)
|
||||
|
||||
Output includes: ETH balance + USD value, token list with prices sorted
|
||||
by value, dust count, total portfolio value in USD.
|
||||
|
||||
Note: Only checks known tokens. Unknown ERC-20s are not discovered.
|
||||
Use the `token` command with a specific contract address for any token.
|
||||
|
||||
### 2. Transaction Details
|
||||
|
||||
Inspect a full transaction by its hash. Shows ETH value transferred,
|
||||
gas used, fee in ETH/USD, status, and decoded ERC-20/ERC-721 transfers.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py \
|
||||
tx 0xabc123...your_tx_hash_here
|
||||
```
|
||||
|
||||
Output: hash, block, from, to, value (ETH + USD), gas price, gas used,
|
||||
fee, status, contract creation address (if any), token transfers.
|
||||
|
||||
### 3. Token Info
|
||||
|
||||
Get ERC-20 token metadata: name, symbol, decimals, total supply, price,
|
||||
market cap, and contract code size.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py \
|
||||
token 0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913
|
||||
```
|
||||
|
||||
Output: name, symbol, decimals, total supply, price, market cap.
|
||||
Reads name/symbol/decimals directly from the contract via eth_call.
|
||||
|
||||
### 4. Gas Analysis
|
||||
|
||||
Detailed gas analysis with cost estimates for common operations.
|
||||
Shows current gas price, base fee trends over 10 blocks, block
|
||||
utilization, and estimated costs for ETH transfers, ERC-20 transfers,
|
||||
and swaps.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py gas
|
||||
```
|
||||
|
||||
Output: current gas price, base fee, block utilization, 10-block trend,
|
||||
cost estimates in ETH and USD.
|
||||
|
||||
Note: Base is an L2 — actual transaction costs include an L1 data
|
||||
posting fee that depends on calldata size and L1 gas prices. The
|
||||
estimates shown are for L2 execution only.
|
||||
|
||||
### 5. Contract Inspection
|
||||
|
||||
Inspect an address: determine if it's an EOA or contract, detect
|
||||
ERC-20/ERC-721/ERC-1155 interfaces, resolve EIP-1967 proxy
|
||||
implementation addresses.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py \
|
||||
contract 0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913
|
||||
```
|
||||
|
||||
Output: is_contract, code size, ETH balance, detected interfaces
|
||||
(ERC-20, ERC-721, ERC-1155), ERC-20 metadata, proxy implementation
|
||||
address.
|
||||
|
||||
### 6. Whale Detector
|
||||
|
||||
Scan the most recent block for large ETH transfers with USD values.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py \
|
||||
whales --min-eth 1.0
|
||||
```
|
||||
|
||||
Note: scans the latest block only — point-in-time snapshot, not historical.
|
||||
Default threshold is 1.0 ETH (lower than Solana's default since ETH
|
||||
values are higher).
|
||||
|
||||
### 7. Network Stats
|
||||
|
||||
Live Base network health: latest block, chain ID, gas price, base fee,
|
||||
block utilization, transaction count, and ETH price.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py stats
|
||||
```
|
||||
|
||||
### 8. Price Lookup
|
||||
|
||||
Quick price check for any token by contract address or known symbol.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py price ETH
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py price USDC
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py price AERO
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py price DEGEN
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py price 0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913
|
||||
```
|
||||
|
||||
Known symbols: ETH, WETH, USDC, cbETH, AERO, DEGEN, TOSHI, BRETT,
|
||||
WELL, wstETH, rETH, cbBTC.
|
||||
|
||||
---
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- **CoinGecko rate-limits** — free tier allows ~10-30 requests/minute.
|
||||
Price lookups use 1 request per token. Use `--no-prices` for speed.
|
||||
- **Public RPC rate-limits** — Base's public RPC limits requests.
|
||||
For production use, set BASE_RPC_URL to a private endpoint
|
||||
(Alchemy, QuickNode, Infura).
|
||||
- **Wallet shows known tokens only** — unlike Solana, EVM chains have no
|
||||
built-in "get all tokens" RPC. The wallet command checks ~15 popular
|
||||
Base tokens via `balanceOf`. Unknown ERC-20s won't appear. Use the
|
||||
`token` command for any specific contract.
|
||||
- **Token names read from contract** — if a contract doesn't implement
|
||||
`name()` or `symbol()`, these fields may be empty. Known tokens have
|
||||
hardcoded labels as fallback.
|
||||
- **Gas estimates are L2 only** — Base transaction costs include an L1
|
||||
data posting fee (depends on calldata size and L1 gas prices). The gas
|
||||
command estimates L2 execution cost only.
|
||||
- **Whale detector scans latest block only** — not historical. Results
|
||||
vary by the moment you query. Default threshold is 1.0 ETH.
|
||||
- **Proxy detection** — only EIP-1967 proxies are detected. Other proxy
|
||||
patterns (EIP-1167 minimal proxy, custom storage slots) are not checked.
|
||||
- **Retry on 429** — both RPC and CoinGecko calls retry up to 2 times
|
||||
with exponential backoff on rate-limit errors.
|
||||
|
||||
---
|
||||
|
||||
## Verification
|
||||
|
||||
```bash
|
||||
# Should print Base chain ID (8453), latest block, gas price, and ETH price
|
||||
python3 ~/.hermes/skills/blockchain/base/scripts/base_client.py stats
|
||||
```
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,116 @@
|
||||
---
|
||||
name: blender-mcp
|
||||
description: Control Blender directly from Hermes via socket connection to the blender-mcp addon. Create 3D objects, materials, animations, and run arbitrary Blender Python (bpy) code. Use when user wants to create or modify anything in Blender.
|
||||
version: 1.0.0
|
||||
requires: Blender 4.3+ (desktop instance required, headless not supported)
|
||||
author: alireza78a
|
||||
tags: [blender, 3d, animation, modeling, bpy, mcp]
|
||||
---
|
||||
|
||||
# Blender MCP
|
||||
|
||||
Control a running Blender instance from Hermes via socket on TCP port 9876.
|
||||
|
||||
## Setup (one-time)
|
||||
|
||||
### 1. Install the Blender addon
|
||||
|
||||
curl -sL https://raw.githubusercontent.com/ahujasid/blender-mcp/main/addon.py -o ~/Desktop/blender_mcp_addon.py
|
||||
|
||||
In Blender:
|
||||
Edit > Preferences > Add-ons > Install > select blender_mcp_addon.py
|
||||
Enable "Interface: Blender MCP"
|
||||
|
||||
### 2. Start the socket server in Blender
|
||||
|
||||
Press N in Blender viewport to open sidebar.
|
||||
Find "BlenderMCP" tab and click "Start Server".
|
||||
|
||||
### 3. Verify connection
|
||||
|
||||
nc -z -w2 localhost 9876 && echo "OPEN" || echo "CLOSED"
|
||||
|
||||
## Protocol
|
||||
|
||||
Plain UTF-8 JSON over TCP -- no length prefix.
|
||||
|
||||
Send: {"type": "<command>", "params": {<kwargs>}}
|
||||
Receive: {"status": "success", "result": <value>}
|
||||
{"status": "error", "message": "<reason>"}
|
||||
|
||||
## Available Commands
|
||||
|
||||
| type | params | description |
|
||||
|-------------------------|-------------------|---------------------------------|
|
||||
| execute_code | code (str) | Run arbitrary bpy Python code |
|
||||
| get_scene_info | (none) | List all objects in scene |
|
||||
| get_object_info | object_name (str) | Details on a specific object |
|
||||
| get_viewport_screenshot | (none) | Screenshot of current viewport |
|
||||
|
||||
## Python Helper
|
||||
|
||||
Use this inside execute_code tool calls:
|
||||
|
||||
import socket, json
|
||||
|
||||
def blender_exec(code: str, host="localhost", port=9876, timeout=15):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.connect((host, port))
|
||||
s.settimeout(timeout)
|
||||
payload = json.dumps({"type": "execute_code", "params": {"code": code}})
|
||||
s.sendall(payload.encode("utf-8"))
|
||||
buf = b""
|
||||
while True:
|
||||
try:
|
||||
chunk = s.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
buf += chunk
|
||||
try:
|
||||
json.loads(buf.decode("utf-8"))
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except socket.timeout:
|
||||
break
|
||||
s.close()
|
||||
return json.loads(buf.decode("utf-8"))
|
||||
|
||||
## Common bpy Patterns
|
||||
|
||||
### Clear scene
|
||||
bpy.ops.object.select_all(action='SELECT')
|
||||
bpy.ops.object.delete()
|
||||
|
||||
### Add mesh objects
|
||||
bpy.ops.mesh.primitive_uv_sphere_add(radius=1, location=(0, 0, 0))
|
||||
bpy.ops.mesh.primitive_cube_add(size=2, location=(3, 0, 0))
|
||||
bpy.ops.mesh.primitive_cylinder_add(radius=0.5, depth=2, location=(-3, 0, 0))
|
||||
|
||||
### Create and assign material
|
||||
mat = bpy.data.materials.new(name="MyMat")
|
||||
mat.use_nodes = True
|
||||
bsdf = mat.node_tree.nodes.get("Principled BSDF")
|
||||
bsdf.inputs["Base Color"].default_value = (R, G, B, 1.0)
|
||||
bsdf.inputs["Roughness"].default_value = 0.3
|
||||
bsdf.inputs["Metallic"].default_value = 0.0
|
||||
obj.data.materials.append(mat)
|
||||
|
||||
### Keyframe animation
|
||||
obj.location = (0, 0, 0)
|
||||
obj.keyframe_insert(data_path="location", frame=1)
|
||||
obj.location = (0, 0, 3)
|
||||
obj.keyframe_insert(data_path="location", frame=60)
|
||||
|
||||
### Render to file
|
||||
bpy.context.scene.render.filepath = "/tmp/render.png"
|
||||
bpy.context.scene.render.engine = 'CYCLES'
|
||||
bpy.ops.render.render(write_still=True)
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- Must check socket is open before running (nc -z localhost 9876)
|
||||
- Addon server must be started inside Blender each session (N-panel > BlenderMCP > Connect)
|
||||
- Break complex scenes into multiple smaller execute_code calls to avoid timeouts
|
||||
- Render output path must be absolute (/tmp/...) not relative
|
||||
- shade_smooth() requires object to be selected and in object mode
|
||||
@@ -0,0 +1,46 @@
|
||||
# Meme Generation Examples
|
||||
|
||||
## Example 1: Debugging at 2 AM
|
||||
|
||||
**Topic:** debugging production at 2 AM
|
||||
**Template:** this-is-fine
|
||||
|
||||
```bash
|
||||
python generate_meme.py this-is-fine /tmp/meme.png "PRODUCTION IS DOWN" "This is fine"
|
||||
```
|
||||
|
||||
## Example 2: Developer Priorities
|
||||
|
||||
**Topic:** choosing between writing tests and shipping features
|
||||
**Template:** drake
|
||||
|
||||
```bash
|
||||
python generate_meme.py drake /tmp/meme.png "Writing unit tests" "Shipping straight to prod"
|
||||
```
|
||||
|
||||
## Example 3: Exam Stress
|
||||
|
||||
**Topic:** final exam preparation
|
||||
**Template:** two-buttons
|
||||
|
||||
```bash
|
||||
python generate_meme.py two-buttons /tmp/meme.png "Study everything" "Sleep" "Me at midnight"
|
||||
```
|
||||
|
||||
## Example 4: Escalating Solutions
|
||||
|
||||
**Topic:** fixing a CSS bug
|
||||
**Template:** expanding-brain
|
||||
|
||||
```bash
|
||||
python generate_meme.py expanding-brain /tmp/meme.png "Reading the docs" "Stack Overflow" "!important on everything" "Deleting the stylesheet"
|
||||
```
|
||||
|
||||
## Example 5: Hot Take
|
||||
|
||||
**Topic:** tabs vs spaces
|
||||
**Template:** change-my-mind
|
||||
|
||||
```bash
|
||||
python generate_meme.py change-my-mind /tmp/meme.png "Tabs are just thicc spaces"
|
||||
```
|
||||
@@ -0,0 +1,129 @@
|
||||
---
|
||||
name: meme-generation
|
||||
description: Generate real meme images by picking a template and overlaying text with Pillow. Produces actual .png meme files.
|
||||
version: 2.0.0
|
||||
author: adanaleycio
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [creative, memes, humor, images]
|
||||
related_skills: [ascii-art, generative-widgets]
|
||||
category: creative
|
||||
---
|
||||
|
||||
# Meme Generation
|
||||
|
||||
Generate actual meme images from a topic. Picks a template, writes captions, and renders a real .png file with text overlay.
|
||||
|
||||
## When to Use
|
||||
|
||||
- User asks you to make or generate a meme
|
||||
- User wants a meme about a specific topic, situation, or frustration
|
||||
- User says "meme this" or similar
|
||||
|
||||
## Available Templates
|
||||
|
||||
The script supports **any of the ~100 popular imgflip templates** by name or ID, plus 10 curated templates with hand-tuned text positioning.
|
||||
|
||||
### Curated Templates (custom text placement)
|
||||
|
||||
| ID | Name | Fields | Best for |
|
||||
|----|------|--------|----------|
|
||||
| `this-is-fine` | This is Fine | top, bottom | chaos, denial |
|
||||
| `drake` | Drake Hotline Bling | reject, approve | rejecting/preferring |
|
||||
| `distracted-boyfriend` | Distracted Boyfriend | distraction, current, person | temptation, shifting priorities |
|
||||
| `two-buttons` | Two Buttons | left, right, person | impossible choice |
|
||||
| `expanding-brain` | Expanding Brain | 4 levels | escalating irony |
|
||||
| `change-my-mind` | Change My Mind | statement | hot takes |
|
||||
| `woman-yelling-at-cat` | Woman Yelling at Cat | woman, cat | arguments |
|
||||
| `one-does-not-simply` | One Does Not Simply | top, bottom | deceptively hard things |
|
||||
| `grus-plan` | Gru's Plan | step1-3, realization | plans that backfire |
|
||||
| `batman-slapping-robin` | Batman Slapping Robin | robin, batman | shutting down bad ideas |
|
||||
|
||||
### Dynamic Templates (from imgflip API)
|
||||
|
||||
Any template not in the curated list can be used by name or imgflip ID. These get smart default text positioning (top/bottom for 2-field, evenly spaced for 3+). Search with:
|
||||
```bash
|
||||
python "$SKILL_DIR/scripts/generate_meme.py" --search "disaster"
|
||||
```
|
||||
|
||||
## Procedure
|
||||
|
||||
### Mode 1: Classic Template (default)
|
||||
|
||||
1. Read the user's topic and identify the core dynamic (chaos, dilemma, preference, irony, etc.)
|
||||
2. Pick the template that best matches. Use the "Best for" column, or search with `--search`.
|
||||
3. Write short captions for each field (8-12 words max per field, shorter is better).
|
||||
4. Find the skill's script directory:
|
||||
```
|
||||
SKILL_DIR=$(dirname "$(find ~/.hermes/skills -path '*/meme-generation/SKILL.md' 2>/dev/null | head -1)")
|
||||
```
|
||||
5. Run the generator:
|
||||
```bash
|
||||
python "$SKILL_DIR/scripts/generate_meme.py" <template_id> /tmp/meme.png "caption 1" "caption 2" ...
|
||||
```
|
||||
6. Return the image with `MEDIA:/tmp/meme.png`
|
||||
|
||||
### Mode 2: Custom AI Image (when image_generate is available)
|
||||
|
||||
Use this when no classic template fits, or when the user wants something original.
|
||||
|
||||
1. Write the captions first.
|
||||
2. Use `image_generate` to create a scene that matches the meme concept. Do NOT include any text in the image prompt — text will be added by the script. Describe only the visual scene.
|
||||
3. Find the generated image path from the image_generate result URL. Download it to a local path if needed.
|
||||
4. Run the script with `--image` to overlay text, choosing a mode:
|
||||
- **Overlay** (text directly on image, white with black outline):
|
||||
```bash
|
||||
python "$SKILL_DIR/scripts/generate_meme.py" --image /path/to/scene.png /tmp/meme.png "top text" "bottom text"
|
||||
```
|
||||
- **Bars** (black bars above/below with white text — cleaner, always readable):
|
||||
```bash
|
||||
python "$SKILL_DIR/scripts/generate_meme.py" --image /path/to/scene.png --bars /tmp/meme.png "top text" "bottom text"
|
||||
```
|
||||
Use `--bars` when the image is busy/detailed and text would be hard to read on top of it.
|
||||
5. **Verify with vision** (if `vision_analyze` is available): Check the result looks good:
|
||||
```
|
||||
vision_analyze(image_url="/tmp/meme.png", question="Is the text legible and well-positioned? Does the meme work visually?")
|
||||
```
|
||||
If the vision model flags issues (text hard to read, bad placement, etc.), try the other mode (switch between overlay and bars) or regenerate the scene.
|
||||
6. Return the image with `MEDIA:/tmp/meme.png`
|
||||
|
||||
## Examples
|
||||
|
||||
**"debugging production at 2 AM":**
|
||||
```bash
|
||||
python generate_meme.py this-is-fine /tmp/meme.png "SERVERS ARE ON FIRE" "This is fine"
|
||||
```
|
||||
|
||||
**"choosing between sleep and one more episode":**
|
||||
```bash
|
||||
python generate_meme.py drake /tmp/meme.png "Getting 8 hours of sleep" "One more episode at 3 AM"
|
||||
```
|
||||
|
||||
**"the stages of a Monday morning":**
|
||||
```bash
|
||||
python generate_meme.py expanding-brain /tmp/meme.png "Setting an alarm" "Setting 5 alarms" "Sleeping through all alarms" "Working from bed"
|
||||
```
|
||||
|
||||
## Listing Templates
|
||||
|
||||
To see all available templates:
|
||||
```bash
|
||||
python generate_meme.py --list
|
||||
```
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- Keep captions SHORT. Memes with long text look terrible.
|
||||
- Match the number of text arguments to the template's field count.
|
||||
- Pick the template that fits the joke structure, not just the topic.
|
||||
- Do not generate hateful, abusive, or personally targeted content.
|
||||
- The script caches template images in `scripts/.cache/` after first download.
|
||||
|
||||
## Verification
|
||||
|
||||
The output is correct if:
|
||||
- A .png file was created at the output path
|
||||
- Text is legible (white with black outline) on the template
|
||||
- The joke lands — caption matches the template's intended structure
|
||||
- File can be delivered via MEDIA: path
|
||||
@@ -0,0 +1 @@
|
||||
.cache/
|
||||
@@ -0,0 +1,471 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate a meme image by overlaying text on a template.
|
||||
|
||||
Usage:
|
||||
python generate_meme.py <template_id_or_name> <output_path> <text1> [text2] [text3] [text4]
|
||||
|
||||
Example:
|
||||
python generate_meme.py drake /tmp/meme.png "Writing tests" "Shipping to prod and hoping"
|
||||
python generate_meme.py "Disaster Girl" /tmp/meme.png "Top text" "Bottom text"
|
||||
python generate_meme.py --list # show curated templates
|
||||
python generate_meme.py --search "distracted" # search all imgflip templates
|
||||
|
||||
Templates with custom text positioning are in templates.json (10 curated).
|
||||
Any of the ~100 popular imgflip templates can also be used by name or ID —
|
||||
unknown templates get smart default text positioning based on their box_count.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import textwrap
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import requests as _requests
|
||||
except ImportError:
|
||||
_requests = None
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
SCRIPT_DIR = Path(__file__).parent
|
||||
TEMPLATES_FILE = SCRIPT_DIR / "templates.json"
|
||||
CACHE_DIR = SCRIPT_DIR / ".cache"
|
||||
IMGFLIP_API = "https://api.imgflip.com/get_memes"
|
||||
IMGFLIP_CACHE_FILE = CACHE_DIR / "imgflip_memes.json"
|
||||
IMGFLIP_CACHE_MAX_AGE = 86400 # 24 hours
|
||||
|
||||
|
||||
def _fetch_url(url: str, timeout: int = 15) -> bytes:
|
||||
"""Fetch URL content, using requests if available, else urllib."""
|
||||
if _requests is not None:
|
||||
resp = _requests.get(url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
import urllib.request
|
||||
return urllib.request.urlopen(url, timeout=timeout).read()
|
||||
|
||||
|
||||
def load_curated_templates() -> dict:
|
||||
"""Load templates with hand-tuned text field positions."""
|
||||
with open(TEMPLATES_FILE) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _default_fields(box_count: int) -> list:
|
||||
"""Generate sensible default text field positions for unknown templates."""
|
||||
if box_count <= 0:
|
||||
box_count = 2
|
||||
if box_count == 1:
|
||||
return [{"name": "text", "x_pct": 0.5, "y_pct": 0.5, "w_pct": 0.90, "align": "center"}]
|
||||
if box_count == 2:
|
||||
return [
|
||||
{"name": "top", "x_pct": 0.5, "y_pct": 0.08, "w_pct": 0.95, "align": "center"},
|
||||
{"name": "bottom", "x_pct": 0.5, "y_pct": 0.92, "w_pct": 0.95, "align": "center"},
|
||||
]
|
||||
# 3+: evenly space vertically
|
||||
fields = []
|
||||
for i in range(box_count):
|
||||
y = 0.08 + (0.84 * i / (box_count - 1)) if box_count > 1 else 0.5
|
||||
fields.append({
|
||||
"name": f"text{i+1}",
|
||||
"x_pct": 0.5,
|
||||
"y_pct": round(y, 2),
|
||||
"w_pct": 0.90,
|
||||
"align": "center",
|
||||
})
|
||||
return fields
|
||||
|
||||
|
||||
def fetch_imgflip_templates() -> list:
|
||||
"""Fetch popular meme templates from imgflip API. Cached for 24h."""
|
||||
import time
|
||||
|
||||
CACHE_DIR.mkdir(exist_ok=True)
|
||||
# Check cache
|
||||
if IMGFLIP_CACHE_FILE.exists():
|
||||
age = time.time() - IMGFLIP_CACHE_FILE.stat().st_mtime
|
||||
if age < IMGFLIP_CACHE_MAX_AGE:
|
||||
with open(IMGFLIP_CACHE_FILE) as f:
|
||||
return json.load(f)
|
||||
|
||||
try:
|
||||
data = json.loads(_fetch_url(IMGFLIP_API))
|
||||
memes = data.get("data", {}).get("memes", [])
|
||||
with open(IMGFLIP_CACHE_FILE, "w") as f:
|
||||
json.dump(memes, f)
|
||||
return memes
|
||||
except Exception as e:
|
||||
# If fetch fails and we have stale cache, use it
|
||||
if IMGFLIP_CACHE_FILE.exists():
|
||||
with open(IMGFLIP_CACHE_FILE) as f:
|
||||
return json.load(f)
|
||||
print(f"Warning: could not fetch imgflip templates: {e}", file=sys.stderr)
|
||||
return []
|
||||
|
||||
|
||||
def _slugify(name: str) -> str:
|
||||
"""Convert a template name to a slug for matching."""
|
||||
return name.lower().replace(" ", "-").replace("'", "").replace("\"", "")
|
||||
|
||||
|
||||
def resolve_template(identifier: str) -> dict:
|
||||
"""Resolve a template by curated ID, imgflip name, or imgflip ID.
|
||||
|
||||
Returns dict with: name, url, fields, source.
|
||||
"""
|
||||
curated = load_curated_templates()
|
||||
|
||||
# 1. Exact curated ID match
|
||||
if identifier in curated:
|
||||
tmpl = curated[identifier]
|
||||
return {**tmpl, "source": "curated"}
|
||||
|
||||
# 2. Slugified curated match
|
||||
slug = _slugify(identifier)
|
||||
for tid, tmpl in curated.items():
|
||||
if _slugify(tmpl["name"]) == slug or tid == slug:
|
||||
return {**tmpl, "source": "curated"}
|
||||
|
||||
# 3. Search imgflip templates
|
||||
imgflip_memes = fetch_imgflip_templates()
|
||||
slug_lower = slug.lower()
|
||||
id_lower = identifier.strip()
|
||||
|
||||
for meme in imgflip_memes:
|
||||
meme_slug = _slugify(meme["name"])
|
||||
# Check curated first for this imgflip template (custom positioning)
|
||||
for tid, ctmpl in curated.items():
|
||||
if _slugify(ctmpl["name"]) == meme_slug:
|
||||
if meme_slug == slug_lower or meme["id"] == id_lower:
|
||||
return {**ctmpl, "source": "curated"}
|
||||
|
||||
if meme_slug == slug_lower or meme["id"] == id_lower or slug_lower in meme_slug:
|
||||
return {
|
||||
"name": meme["name"],
|
||||
"url": meme["url"],
|
||||
"fields": _default_fields(meme.get("box_count", 2)),
|
||||
"source": "imgflip",
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_template_image(url: str) -> Image.Image:
|
||||
"""Download a template image, caching it locally."""
|
||||
CACHE_DIR.mkdir(exist_ok=True)
|
||||
# Use URL hash as cache key
|
||||
cache_name = url.split("/")[-1]
|
||||
cache_path = CACHE_DIR / cache_name
|
||||
|
||||
# Always cache as PNG to avoid JPEG/RGBA conflicts
|
||||
cache_path = cache_path.with_suffix(".png")
|
||||
|
||||
if cache_path.exists():
|
||||
return Image.open(cache_path).convert("RGBA")
|
||||
|
||||
data = _fetch_url(url)
|
||||
img = Image.open(BytesIO(data)).convert("RGBA")
|
||||
img.save(cache_path, "PNG")
|
||||
return img
|
||||
|
||||
|
||||
def find_font(size: int) -> ImageFont.FreeTypeFont:
|
||||
"""Find a bold font for meme text. Tries Impact, then falls back."""
|
||||
candidates = [
|
||||
"/usr/share/fonts/truetype/msttcorefonts/Impact.ttf",
|
||||
"/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
|
||||
"/usr/share/fonts/liberation-sans/LiberationSans-Bold.ttf",
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
|
||||
"/usr/share/fonts/dejavu-sans/DejaVuSans-Bold.ttf",
|
||||
"/System/Library/Fonts/Helvetica.ttc",
|
||||
"/System/Library/Fonts/SFCompact.ttf",
|
||||
]
|
||||
for path in candidates:
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
return ImageFont.truetype(path, size)
|
||||
except (OSError, IOError):
|
||||
continue
|
||||
# Last resort: Pillow default
|
||||
try:
|
||||
return ImageFont.truetype("DejaVuSans-Bold", size)
|
||||
except (OSError, IOError):
|
||||
return ImageFont.load_default()
|
||||
|
||||
|
||||
def _wrap_text(text: str, font: ImageFont.FreeTypeFont, max_width: int) -> str:
|
||||
"""Word-wrap text to fit within max_width pixels. Never breaks mid-word."""
|
||||
words = text.split()
|
||||
if not words:
|
||||
return text
|
||||
lines = []
|
||||
current_line = words[0]
|
||||
for word in words[1:]:
|
||||
test_line = current_line + " " + word
|
||||
if font.getlength(test_line) <= max_width:
|
||||
current_line = test_line
|
||||
else:
|
||||
lines.append(current_line)
|
||||
current_line = word
|
||||
lines.append(current_line)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def draw_outlined_text(
|
||||
draw: ImageDraw.ImageDraw,
|
||||
text: str,
|
||||
x: int,
|
||||
y: int,
|
||||
font_size: int,
|
||||
max_width: int,
|
||||
align: str = "center",
|
||||
):
|
||||
"""Draw white text with black outline, auto-scaled to fit max_width."""
|
||||
# Auto-scale: reduce font size until text fits reasonably
|
||||
size = font_size
|
||||
while size > 12:
|
||||
font = find_font(size)
|
||||
wrapped = _wrap_text(text, font, max_width)
|
||||
bbox = draw.multiline_textbbox((0, 0), wrapped, font=font, align=align)
|
||||
text_w = bbox[2] - bbox[0]
|
||||
line_count = wrapped.count("\n") + 1
|
||||
# Accept if width fits and not too many lines
|
||||
if text_w <= max_width * 1.05 and line_count <= 4:
|
||||
break
|
||||
size -= 2
|
||||
else:
|
||||
font = find_font(size)
|
||||
wrapped = _wrap_text(text, font, max_width)
|
||||
|
||||
# Measure total text block
|
||||
bbox = draw.multiline_textbbox((0, 0), wrapped, font=font, align=align)
|
||||
text_w = bbox[2] - bbox[0]
|
||||
text_h = bbox[3] - bbox[1]
|
||||
|
||||
# Center horizontally at x, vertically at y
|
||||
tx = x - text_w // 2
|
||||
ty = y - text_h // 2
|
||||
|
||||
# Draw outline (black border)
|
||||
outline_range = max(2, font.size // 18)
|
||||
for dx in range(-outline_range, outline_range + 1):
|
||||
for dy in range(-outline_range, outline_range + 1):
|
||||
if dx == 0 and dy == 0:
|
||||
continue
|
||||
draw.multiline_text(
|
||||
(tx + dx, ty + dy), wrapped, font=font, fill="black", align=align
|
||||
)
|
||||
# Draw main text (white)
|
||||
draw.multiline_text((tx, ty), wrapped, font=font, fill="white", align=align)
|
||||
|
||||
|
||||
def _overlay_on_image(img: Image.Image, texts: list, fields: list) -> Image.Image:
|
||||
"""Overlay meme text directly on an image using field positions."""
|
||||
draw = ImageDraw.Draw(img)
|
||||
w, h = img.size
|
||||
base_font_size = max(16, min(w, h) // 12)
|
||||
|
||||
for i, field in enumerate(fields):
|
||||
if i >= len(texts):
|
||||
break
|
||||
text = texts[i].strip()
|
||||
if not text:
|
||||
continue
|
||||
fx = int(field["x_pct"] * w)
|
||||
fy = int(field["y_pct"] * h)
|
||||
fw = int(field["w_pct"] * w)
|
||||
draw_outlined_text(draw, text, fx, fy, base_font_size, fw, field.get("align", "center"))
|
||||
return img
|
||||
|
||||
|
||||
def _add_bars(img: Image.Image, texts: list) -> Image.Image:
|
||||
"""Add black bars with white text above/below the image.
|
||||
|
||||
Distributes texts across bars: first text on top bar, last text on
|
||||
bottom bar, any middle texts overlaid on the image center.
|
||||
"""
|
||||
w, h = img.size
|
||||
bar_font_size = max(20, w // 16)
|
||||
font = find_font(bar_font_size)
|
||||
padding = bar_font_size // 2
|
||||
|
||||
top_text = texts[0].strip() if texts else ""
|
||||
bottom_text = texts[-1].strip() if len(texts) > 1 else ""
|
||||
middle_texts = [t.strip() for t in texts[1:-1]] if len(texts) > 2 else []
|
||||
|
||||
def _measure_bar(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
wrapped = _wrap_text(text, font, int(w * 0.92))
|
||||
bbox = ImageDraw.Draw(Image.new("RGB", (1, 1))).multiline_textbbox(
|
||||
(0, 0), wrapped, font=font, align="center"
|
||||
)
|
||||
return (bbox[3] - bbox[1]) + padding * 2
|
||||
|
||||
top_h = _measure_bar(top_text)
|
||||
bottom_h = _measure_bar(bottom_text)
|
||||
new_h = h + top_h + bottom_h
|
||||
|
||||
canvas = Image.new("RGB", (w, new_h), (0, 0, 0))
|
||||
canvas.paste(img.convert("RGB"), (0, top_h))
|
||||
draw = ImageDraw.Draw(canvas)
|
||||
|
||||
if top_text:
|
||||
wrapped = _wrap_text(top_text, font, int(w * 0.92))
|
||||
bbox = draw.multiline_textbbox((0, 0), wrapped, font=font, align="center")
|
||||
tw = bbox[2] - bbox[0]
|
||||
th = bbox[3] - bbox[1]
|
||||
tx = (w - tw) // 2
|
||||
ty = (top_h - th) // 2
|
||||
draw.multiline_text((tx, ty), wrapped, font=font, fill="white", align="center")
|
||||
|
||||
if bottom_text:
|
||||
wrapped = _wrap_text(bottom_text, font, int(w * 0.92))
|
||||
bbox = draw.multiline_textbbox((0, 0), wrapped, font=font, align="center")
|
||||
tw = bbox[2] - bbox[0]
|
||||
th = bbox[3] - bbox[1]
|
||||
tx = (w - tw) // 2
|
||||
ty = top_h + h + (bottom_h - th) // 2
|
||||
draw.multiline_text((tx, ty), wrapped, font=font, fill="white", align="center")
|
||||
|
||||
# Overlay any middle texts centered on the image
|
||||
if middle_texts:
|
||||
mid_fields = _default_fields(len(middle_texts))
|
||||
# Shift y positions to account for top bar offset
|
||||
for field in mid_fields:
|
||||
field["y_pct"] = (top_h + field["y_pct"] * h) / new_h
|
||||
field["w_pct"] = 0.90
|
||||
_overlay_on_image(canvas, middle_texts, mid_fields)
|
||||
|
||||
return canvas
|
||||
|
||||
|
||||
def generate_meme(template_id: str, texts: list[str], output_path: str) -> str:
|
||||
"""Generate a meme from a template and save it. Returns the path."""
|
||||
tmpl = resolve_template(template_id)
|
||||
|
||||
if tmpl is None:
|
||||
print(f"Unknown template: {template_id}", file=sys.stderr)
|
||||
print("Use --list to see curated templates or --search to find imgflip templates.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
fields = tmpl["fields"]
|
||||
print(f"Using template: {tmpl['name']} ({tmpl['source']}, {len(fields)} fields)", file=sys.stderr)
|
||||
|
||||
img = get_template_image(tmpl["url"])
|
||||
img = _overlay_on_image(img, texts, fields)
|
||||
|
||||
output = Path(output_path)
|
||||
if output.suffix.lower() in (".jpg", ".jpeg"):
|
||||
img = img.convert("RGB")
|
||||
img.save(str(output), quality=95)
|
||||
return str(output)
|
||||
|
||||
|
||||
def generate_from_image(
|
||||
image_path: str, texts: list[str], output_path: str, use_bars: bool = False
|
||||
) -> str:
|
||||
"""Generate a meme from a custom image (e.g. AI-generated). Returns the path."""
|
||||
img = Image.open(image_path).convert("RGBA")
|
||||
print(f"Custom image: {img.size[0]}x{img.size[1]}, {len(texts)} text(s), mode={'bars' if use_bars else 'overlay'}", file=sys.stderr)
|
||||
|
||||
if use_bars:
|
||||
result = _add_bars(img, texts)
|
||||
else:
|
||||
fields = _default_fields(len(texts))
|
||||
result = _overlay_on_image(img, texts, fields)
|
||||
|
||||
output = Path(output_path)
|
||||
if output.suffix.lower() in (".jpg", ".jpeg"):
|
||||
result = result.convert("RGB")
|
||||
result.save(str(output), quality=95)
|
||||
return str(output)
|
||||
|
||||
|
||||
def list_templates():
|
||||
"""Print curated templates with custom positioning."""
|
||||
templates = load_curated_templates()
|
||||
print(f"{'ID':<25} {'Name':<30} {'Fields':<8} Best for")
|
||||
print("-" * 90)
|
||||
for tid, tmpl in sorted(templates.items()):
|
||||
fields = len(tmpl["fields"])
|
||||
print(f"{tid:<25} {tmpl['name']:<30} {fields:<8} {tmpl['best_for']}")
|
||||
print(f"\n{len(templates)} curated templates with custom text positioning.")
|
||||
print("Use --search to find any of the ~100 popular imgflip templates.")
|
||||
|
||||
|
||||
def search_templates(query: str):
|
||||
"""Search imgflip templates by name."""
|
||||
imgflip_memes = fetch_imgflip_templates()
|
||||
curated = load_curated_templates()
|
||||
curated_slugs = {_slugify(t["name"]) for t in curated.values()}
|
||||
query_lower = query.lower()
|
||||
|
||||
matches = []
|
||||
for meme in imgflip_memes:
|
||||
if query_lower in meme["name"].lower():
|
||||
slug = _slugify(meme["name"])
|
||||
has_custom = "curated" if slug in curated_slugs else "default"
|
||||
matches.append((meme["name"], meme["id"], meme.get("box_count", 2), has_custom))
|
||||
|
||||
if not matches:
|
||||
print(f"No templates found matching '{query}'")
|
||||
return
|
||||
|
||||
print(f"{'Name':<40} {'ID':<12} {'Fields':<8} Positioning")
|
||||
print("-" * 75)
|
||||
for name, mid, boxes, positioning in matches:
|
||||
print(f"{name:<40} {mid:<12} {boxes:<8} {positioning}")
|
||||
print(f"\n{len(matches)} template(s) found. Use the name or ID as the first argument.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: generate_meme.py <template_id_or_name> <output_path> <text1> [text2] ...")
|
||||
print(" generate_meme.py --image <path> [--bars] <output_path> <text1> [text2] ...")
|
||||
print(" generate_meme.py --list # curated templates")
|
||||
print(" generate_meme.py --search <query> # search all imgflip templates")
|
||||
sys.exit(1)
|
||||
|
||||
if sys.argv[1] == "--list":
|
||||
list_templates()
|
||||
sys.exit(0)
|
||||
|
||||
if sys.argv[1] == "--search":
|
||||
if len(sys.argv) < 3:
|
||||
print("Usage: generate_meme.py --search <query>")
|
||||
sys.exit(1)
|
||||
search_templates(sys.argv[2])
|
||||
sys.exit(0)
|
||||
|
||||
if sys.argv[1] == "--image":
|
||||
# Custom image mode: --image <path> [--bars] <output> <text1> ...
|
||||
args = sys.argv[2:]
|
||||
if len(args) < 3:
|
||||
print("Usage: generate_meme.py --image <image_path> [--bars] <output_path> <text1> ...")
|
||||
sys.exit(1)
|
||||
image_path = args.pop(0)
|
||||
use_bars = False
|
||||
if args and args[0] == "--bars":
|
||||
use_bars = True
|
||||
args.pop(0)
|
||||
if len(args) < 2:
|
||||
print("Need at least: output_path and one text argument")
|
||||
sys.exit(1)
|
||||
output_path = args.pop(0)
|
||||
result = generate_from_image(image_path, args, output_path, use_bars=use_bars)
|
||||
print(f"Meme saved to: {result}")
|
||||
sys.exit(0)
|
||||
|
||||
if len(sys.argv) < 4:
|
||||
print("Need at least: template_id_or_name, output_path, and one text argument")
|
||||
sys.exit(1)
|
||||
|
||||
template_id = sys.argv[1]
|
||||
output_path = sys.argv[2]
|
||||
texts = sys.argv[3:]
|
||||
|
||||
result = generate_meme(template_id, texts, output_path)
|
||||
print(f"Meme saved to: {result}")
|
||||
@@ -0,0 +1,97 @@
|
||||
{
|
||||
"this-is-fine": {
|
||||
"name": "This is Fine",
|
||||
"url": "https://i.imgflip.com/wxica.jpg",
|
||||
"best_for": "chaos, denial, pretending things are okay",
|
||||
"fields": [
|
||||
{"name": "top", "x_pct": 0.5, "y_pct": 0.08, "w_pct": 0.95, "align": "center"},
|
||||
{"name": "bottom", "x_pct": 0.5, "y_pct": 0.92, "w_pct": 0.95, "align": "center"}
|
||||
]
|
||||
},
|
||||
"drake": {
|
||||
"name": "Drake Hotline Bling",
|
||||
"url": "https://i.imgflip.com/30b1gx.jpg",
|
||||
"best_for": "rejecting one thing, preferring another",
|
||||
"fields": [
|
||||
{"name": "reject", "x_pct": 0.73, "y_pct": 0.25, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "approve", "x_pct": 0.73, "y_pct": 0.75, "w_pct": 0.45, "align": "center"}
|
||||
]
|
||||
},
|
||||
"distracted-boyfriend": {
|
||||
"name": "Distracted Boyfriend",
|
||||
"url": "https://i.imgflip.com/1ur9b0.jpg",
|
||||
"best_for": "distraction, shifting priorities, temptation",
|
||||
"fields": [
|
||||
{"name": "distraction", "x_pct": 0.18, "y_pct": 0.90, "w_pct": 0.30, "align": "center"},
|
||||
{"name": "current", "x_pct": 0.55, "y_pct": 0.90, "w_pct": 0.30, "align": "center"},
|
||||
{"name": "person", "x_pct": 0.82, "y_pct": 0.90, "w_pct": 0.30, "align": "center"}
|
||||
]
|
||||
},
|
||||
"two-buttons": {
|
||||
"name": "Two Buttons",
|
||||
"url": "https://i.imgflip.com/1g8my4.jpg",
|
||||
"best_for": "impossible choice, dilemma between two options",
|
||||
"fields": [
|
||||
{"name": "left_button", "x_pct": 0.30, "y_pct": 0.20, "w_pct": 0.28, "align": "center"},
|
||||
{"name": "right_button", "x_pct": 0.62, "y_pct": 0.12, "w_pct": 0.28, "align": "center"},
|
||||
{"name": "person", "x_pct": 0.5, "y_pct": 0.85, "w_pct": 0.90, "align": "center"}
|
||||
]
|
||||
},
|
||||
"expanding-brain": {
|
||||
"name": "Expanding Brain",
|
||||
"url": "https://i.imgflip.com/1jwhww.jpg",
|
||||
"best_for": "escalating irony, increasingly absurd ideas",
|
||||
"fields": [
|
||||
{"name": "level1", "x_pct": 0.25, "y_pct": 0.12, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "level2", "x_pct": 0.25, "y_pct": 0.38, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "level3", "x_pct": 0.25, "y_pct": 0.63, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "level4", "x_pct": 0.25, "y_pct": 0.88, "w_pct": 0.45, "align": "center"}
|
||||
]
|
||||
},
|
||||
"change-my-mind": {
|
||||
"name": "Change My Mind",
|
||||
"url": "https://i.imgflip.com/24y43o.jpg",
|
||||
"best_for": "strong or ironic opinion, controversial take",
|
||||
"fields": [
|
||||
{"name": "statement", "x_pct": 0.58, "y_pct": 0.78, "w_pct": 0.35, "align": "center"}
|
||||
]
|
||||
},
|
||||
"woman-yelling-at-cat": {
|
||||
"name": "Woman Yelling at Cat",
|
||||
"url": "https://i.imgflip.com/345v97.jpg",
|
||||
"best_for": "argument, blame, misunderstanding",
|
||||
"fields": [
|
||||
{"name": "woman", "x_pct": 0.27, "y_pct": 0.10, "w_pct": 0.50, "align": "center"},
|
||||
{"name": "cat", "x_pct": 0.76, "y_pct": 0.10, "w_pct": 0.44, "align": "center"}
|
||||
]
|
||||
},
|
||||
"one-does-not-simply": {
|
||||
"name": "One Does Not Simply",
|
||||
"url": "https://i.imgflip.com/1bij.jpg",
|
||||
"best_for": "something that sounds easy but is actually hard",
|
||||
"fields": [
|
||||
{"name": "top", "x_pct": 0.5, "y_pct": 0.08, "w_pct": 0.95, "align": "center"},
|
||||
{"name": "bottom", "x_pct": 0.5, "y_pct": 0.92, "w_pct": 0.95, "align": "center"}
|
||||
]
|
||||
},
|
||||
"grus-plan": {
|
||||
"name": "Gru's Plan",
|
||||
"url": "https://i.imgflip.com/26jxvs.jpg",
|
||||
"best_for": "a plan that backfires, unexpected consequence",
|
||||
"fields": [
|
||||
{"name": "step1", "x_pct": 0.5, "y_pct": 0.05, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "step2", "x_pct": 0.5, "y_pct": 0.30, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "step3", "x_pct": 0.5, "y_pct": 0.55, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "realization", "x_pct": 0.5, "y_pct": 0.80, "w_pct": 0.45, "align": "center"}
|
||||
]
|
||||
},
|
||||
"batman-slapping-robin": {
|
||||
"name": "Batman Slapping Robin",
|
||||
"url": "https://i.imgflip.com/9ehk.jpg",
|
||||
"best_for": "shutting down a bad idea, correcting someone",
|
||||
"fields": [
|
||||
{"name": "robin", "x_pct": 0.28, "y_pct": 0.08, "w_pct": 0.50, "align": "center"},
|
||||
{"name": "batman", "x_pct": 0.72, "y_pct": 0.08, "w_pct": 0.50, "align": "center"}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
# MCP
|
||||
|
||||
Skills for building, testing, and deploying MCP (Model Context Protocol) servers.
|
||||
@@ -0,0 +1,299 @@
|
||||
---
|
||||
name: fastmcp
|
||||
description: Build, test, inspect, install, and deploy MCP servers with FastMCP in Python. Use when creating a new MCP server, wrapping an API or database as MCP tools, exposing resources or prompts, or preparing a FastMCP server for Claude Code, Cursor, or HTTP deployment.
|
||||
version: 1.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [MCP, FastMCP, Python, Tools, Resources, Prompts, Deployment]
|
||||
homepage: https://gofastmcp.com
|
||||
related_skills: [native-mcp, mcporter]
|
||||
prerequisites:
|
||||
commands: [python3]
|
||||
---
|
||||
|
||||
# FastMCP
|
||||
|
||||
Build MCP servers in Python with FastMCP, validate them locally, install them into MCP clients, and deploy them as HTTP endpoints.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use this skill when the task is to:
|
||||
|
||||
- create a new MCP server in Python
|
||||
- wrap an API, database, CLI, or file-processing workflow as MCP tools
|
||||
- expose resources or prompts in addition to tools
|
||||
- smoke-test a server with the FastMCP CLI before wiring it into Hermes or another client
|
||||
- install a server into Claude Code, Claude Desktop, Cursor, or a similar MCP client
|
||||
- prepare a FastMCP server repo for HTTP deployment
|
||||
|
||||
Use `native-mcp` when the server already exists and only needs to be connected to Hermes. Use `mcporter` when the goal is ad-hoc CLI access to an existing MCP server instead of building one.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Install FastMCP in the working environment first:
|
||||
|
||||
```bash
|
||||
pip install fastmcp
|
||||
fastmcp version
|
||||
```
|
||||
|
||||
For the API template, install `httpx` if it is not already present:
|
||||
|
||||
```bash
|
||||
pip install httpx
|
||||
```
|
||||
|
||||
## Included Files
|
||||
|
||||
### Templates
|
||||
|
||||
- `templates/api_wrapper.py` - REST API wrapper with auth header support
|
||||
- `templates/database_server.py` - read-only SQLite query server
|
||||
- `templates/file_processor.py` - text-file inspection and search server
|
||||
|
||||
### Scripts
|
||||
|
||||
- `scripts/scaffold_fastmcp.py` - copy a starter template and replace the server name placeholder
|
||||
|
||||
### References
|
||||
|
||||
- `references/fastmcp-cli.md` - FastMCP CLI workflow, installation targets, and deployment checks
|
||||
|
||||
## Workflow
|
||||
|
||||
### 1. Pick the Smallest Viable Server Shape
|
||||
|
||||
Choose the narrowest useful surface area first:
|
||||
|
||||
- API wrapper: start with 1-3 high-value endpoints, not the whole API
|
||||
- database server: expose read-only introspection and a constrained query path
|
||||
- file processor: expose deterministic operations with explicit path arguments
|
||||
- prompts/resources: add only when the client needs reusable prompt templates or discoverable documents
|
||||
|
||||
Prefer a thin server with good names, docstrings, and schemas over a large server with vague tools.
|
||||
|
||||
### 2. Scaffold from a Template
|
||||
|
||||
Copy a template directly or use the scaffold helper:
|
||||
|
||||
```bash
|
||||
python ~/.hermes/skills/mcp/fastmcp/scripts/scaffold_fastmcp.py \
|
||||
--template api_wrapper \
|
||||
--name "Acme API" \
|
||||
--output ./acme_server.py
|
||||
```
|
||||
|
||||
Available templates:
|
||||
|
||||
```bash
|
||||
python ~/.hermes/skills/mcp/fastmcp/scripts/scaffold_fastmcp.py --list
|
||||
```
|
||||
|
||||
If copying manually, replace `__SERVER_NAME__` with a real server name.
|
||||
|
||||
### 3. Implement Tools First
|
||||
|
||||
Start with `@mcp.tool` functions before adding resources or prompts.
|
||||
|
||||
Rules for tool design:
|
||||
|
||||
- Give every tool a concrete verb-based name
|
||||
- Write docstrings as user-facing tool descriptions
|
||||
- Keep parameters explicit and typed
|
||||
- Return structured JSON-safe data where possible
|
||||
- Validate unsafe inputs early
|
||||
- Prefer read-only behavior by default for first versions
|
||||
|
||||
Good tool examples:
|
||||
|
||||
- `get_customer`
|
||||
- `search_tickets`
|
||||
- `describe_table`
|
||||
- `summarize_text_file`
|
||||
|
||||
Weak tool examples:
|
||||
|
||||
- `run`
|
||||
- `process`
|
||||
- `do_thing`
|
||||
|
||||
### 4. Add Resources and Prompts Only When They Help
|
||||
|
||||
Add `@mcp.resource` when the client benefits from fetching stable read-only content such as schemas, policy docs, or generated reports.
|
||||
|
||||
Add `@mcp.prompt` when the server should provide a reusable prompt template for a known workflow.
|
||||
|
||||
Do not turn every document into a prompt. Prefer:
|
||||
|
||||
- tools for actions
|
||||
- resources for data/document retrieval
|
||||
- prompts for reusable LLM instructions
|
||||
|
||||
### 5. Test the Server Before Integrating It Anywhere
|
||||
|
||||
Use the FastMCP CLI for local validation:
|
||||
|
||||
```bash
|
||||
fastmcp inspect acme_server.py:mcp
|
||||
fastmcp list acme_server.py --json
|
||||
fastmcp call acme_server.py search_resources query=router limit=5 --json
|
||||
```
|
||||
|
||||
For fast iterative debugging, run the server locally:
|
||||
|
||||
```bash
|
||||
fastmcp run acme_server.py:mcp
|
||||
```
|
||||
|
||||
To test HTTP transport locally:
|
||||
|
||||
```bash
|
||||
fastmcp run acme_server.py:mcp --transport http --host 127.0.0.1 --port 8000
|
||||
fastmcp list http://127.0.0.1:8000/mcp --json
|
||||
fastmcp call http://127.0.0.1:8000/mcp search_resources query=router --json
|
||||
```
|
||||
|
||||
Always run at least one real `fastmcp call` against each new tool before claiming the server works.
|
||||
|
||||
### 6. Install into a Client When Local Validation Passes
|
||||
|
||||
FastMCP can register the server with supported MCP clients:
|
||||
|
||||
```bash
|
||||
fastmcp install claude-code acme_server.py
|
||||
fastmcp install claude-desktop acme_server.py
|
||||
fastmcp install cursor acme_server.py -e .
|
||||
```
|
||||
|
||||
Use `fastmcp discover` to inspect named MCP servers already configured on the machine.
|
||||
|
||||
When the goal is Hermes integration, either:
|
||||
|
||||
- configure the server in `~/.hermes/config.yaml` using the `native-mcp` skill, or
|
||||
- keep using FastMCP CLI commands during development until the interface stabilizes
|
||||
|
||||
### 7. Deploy After the Local Contract Is Stable
|
||||
|
||||
For managed hosting, Prefect Horizon is the path FastMCP documents most directly. Before deployment:
|
||||
|
||||
```bash
|
||||
fastmcp inspect acme_server.py:mcp
|
||||
```
|
||||
|
||||
Make sure the repo contains:
|
||||
|
||||
- a Python file with the FastMCP server object
|
||||
- `requirements.txt` or `pyproject.toml`
|
||||
- any environment-variable documentation needed for deployment
|
||||
|
||||
For generic HTTP hosting, validate the HTTP transport locally first, then deploy on any Python-compatible platform that can expose the server port.
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### API Wrapper Pattern
|
||||
|
||||
Use when exposing a REST or HTTP API as MCP tools.
|
||||
|
||||
Recommended first slice:
|
||||
|
||||
- one read path
|
||||
- one list/search path
|
||||
- optional health check
|
||||
|
||||
Implementation notes:
|
||||
|
||||
- keep auth in environment variables, not hardcoded
|
||||
- centralize request logic in one helper
|
||||
- surface API errors with concise context
|
||||
- normalize inconsistent upstream payloads before returning them
|
||||
|
||||
Start from `templates/api_wrapper.py`.
|
||||
|
||||
### Database Pattern
|
||||
|
||||
Use when exposing safe query and inspection capabilities.
|
||||
|
||||
Recommended first slice:
|
||||
|
||||
- `list_tables`
|
||||
- `describe_table`
|
||||
- one constrained read query tool
|
||||
|
||||
Implementation notes:
|
||||
|
||||
- default to read-only DB access
|
||||
- reject non-`SELECT` SQL in early versions
|
||||
- limit row counts
|
||||
- return rows plus column names
|
||||
|
||||
Start from `templates/database_server.py`.
|
||||
|
||||
### File Processor Pattern
|
||||
|
||||
Use when the server needs to inspect or transform files on demand.
|
||||
|
||||
Recommended first slice:
|
||||
|
||||
- summarize file contents
|
||||
- search within files
|
||||
- extract deterministic metadata
|
||||
|
||||
Implementation notes:
|
||||
|
||||
- accept explicit file paths
|
||||
- check for missing files and encoding failures
|
||||
- cap previews and result counts
|
||||
- avoid shelling out unless a specific external tool is required
|
||||
|
||||
Start from `templates/file_processor.py`.
|
||||
|
||||
## Quality Bar
|
||||
|
||||
Before handing off a FastMCP server, verify all of the following:
|
||||
|
||||
- server imports cleanly
|
||||
- `fastmcp inspect <file.py:mcp>` succeeds
|
||||
- `fastmcp list <server spec> --json` succeeds
|
||||
- every new tool has at least one real `fastmcp call`
|
||||
- environment variables are documented
|
||||
- the tool surface is small enough to understand without guesswork
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### FastMCP command missing
|
||||
|
||||
Install the package in the active environment:
|
||||
|
||||
```bash
|
||||
pip install fastmcp
|
||||
fastmcp version
|
||||
```
|
||||
|
||||
### `fastmcp inspect` fails
|
||||
|
||||
Check that:
|
||||
|
||||
- the file imports without side effects that crash
|
||||
- the FastMCP instance is named correctly in `<file.py:object>`
|
||||
- optional dependencies from the template are installed
|
||||
|
||||
### Tool works in Python but not through CLI
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
fastmcp list server.py --json
|
||||
fastmcp call server.py your_tool_name --json
|
||||
```
|
||||
|
||||
This usually exposes naming mismatches, missing required arguments, or non-serializable return values.
|
||||
|
||||
### Hermes cannot see the deployed server
|
||||
|
||||
The server-building part may be correct while the Hermes config is not. Load the `native-mcp` skill and configure the server in `~/.hermes/config.yaml`, then restart Hermes.
|
||||
|
||||
## References
|
||||
|
||||
For CLI details, install targets, and deployment checks, read `references/fastmcp-cli.md`.
|
||||
@@ -0,0 +1,110 @@
|
||||
# FastMCP CLI Reference
|
||||
|
||||
Use this file when the task needs exact FastMCP CLI workflows rather than the higher-level guidance in `SKILL.md`.
|
||||
|
||||
## Install and Verify
|
||||
|
||||
```bash
|
||||
pip install fastmcp
|
||||
fastmcp version
|
||||
```
|
||||
|
||||
FastMCP documents `pip install fastmcp` and `fastmcp version` as the baseline installation and verification path.
|
||||
|
||||
## Run a Server
|
||||
|
||||
Run a server object from a Python file:
|
||||
|
||||
```bash
|
||||
fastmcp run server.py:mcp
|
||||
```
|
||||
|
||||
Run the same server over HTTP:
|
||||
|
||||
```bash
|
||||
fastmcp run server.py:mcp --transport http --host 127.0.0.1 --port 8000
|
||||
```
|
||||
|
||||
## Inspect a Server
|
||||
|
||||
Inspect what FastMCP will expose:
|
||||
|
||||
```bash
|
||||
fastmcp inspect server.py:mcp
|
||||
```
|
||||
|
||||
This is also the check FastMCP recommends before deploying to Prefect Horizon.
|
||||
|
||||
## List and Call Tools
|
||||
|
||||
List tools from a Python file:
|
||||
|
||||
```bash
|
||||
fastmcp list server.py --json
|
||||
```
|
||||
|
||||
List tools from an HTTP endpoint:
|
||||
|
||||
```bash
|
||||
fastmcp list http://127.0.0.1:8000/mcp --json
|
||||
```
|
||||
|
||||
Call a tool with key-value arguments:
|
||||
|
||||
```bash
|
||||
fastmcp call server.py search_resources query=router limit=5 --json
|
||||
```
|
||||
|
||||
Call a tool with a full JSON input payload:
|
||||
|
||||
```bash
|
||||
fastmcp call server.py create_item '{"name": "Widget", "tags": ["sale"]}' --json
|
||||
```
|
||||
|
||||
## Discover Named MCP Servers
|
||||
|
||||
Find named servers already configured in local MCP-aware tools:
|
||||
|
||||
```bash
|
||||
fastmcp discover
|
||||
```
|
||||
|
||||
FastMCP documents name-based resolution for Claude Desktop, Claude Code, Cursor, Gemini, Goose, and `./mcp.json`.
|
||||
|
||||
## Install into MCP Clients
|
||||
|
||||
Register a server with common clients:
|
||||
|
||||
```bash
|
||||
fastmcp install claude-code server.py
|
||||
fastmcp install claude-desktop server.py
|
||||
fastmcp install cursor server.py -e .
|
||||
```
|
||||
|
||||
FastMCP notes that client installs run in isolated environments, so declare dependencies explicitly when needed with flags such as `--with`, `--env-file`, or editable installs.
|
||||
|
||||
## Deployment Checks
|
||||
|
||||
### Prefect Horizon
|
||||
|
||||
Before pushing to Horizon:
|
||||
|
||||
```bash
|
||||
fastmcp inspect server.py:mcp
|
||||
```
|
||||
|
||||
FastMCP’s Horizon docs expect:
|
||||
|
||||
- a GitHub repo
|
||||
- a Python file containing the FastMCP server object
|
||||
- dependencies declared in `requirements.txt` or `pyproject.toml`
|
||||
- an entrypoint like `main.py:mcp`
|
||||
|
||||
### Generic HTTP Hosting
|
||||
|
||||
Before shipping to any other host:
|
||||
|
||||
1. Start the server locally with HTTP transport.
|
||||
2. Verify `fastmcp list` against the local `/mcp` URL.
|
||||
3. Verify at least one `fastmcp call`.
|
||||
4. Document required environment variables.
|
||||
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Copy a FastMCP starter template into a working file."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
SKILL_DIR = SCRIPT_DIR.parent
|
||||
TEMPLATE_DIR = SKILL_DIR / "templates"
|
||||
PLACEHOLDER = "__SERVER_NAME__"
|
||||
|
||||
|
||||
def list_templates() -> list[str]:
|
||||
return sorted(path.stem for path in TEMPLATE_DIR.glob("*.py"))
|
||||
|
||||
|
||||
def render_template(template_name: str, server_name: str) -> str:
|
||||
template_path = TEMPLATE_DIR / f"{template_name}.py"
|
||||
if not template_path.exists():
|
||||
available = ", ".join(list_templates())
|
||||
raise SystemExit(f"Unknown template '{template_name}'. Available: {available}")
|
||||
return template_path.read_text(encoding="utf-8").replace(PLACEHOLDER, server_name)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--template", help="Template name without .py suffix")
|
||||
parser.add_argument("--name", help="FastMCP server display name")
|
||||
parser.add_argument("--output", help="Destination Python file path")
|
||||
parser.add_argument("--force", action="store_true", help="Overwrite an existing output file")
|
||||
parser.add_argument("--list", action="store_true", help="List available templates and exit")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list:
|
||||
for name in list_templates():
|
||||
print(name)
|
||||
return 0
|
||||
|
||||
if not args.template or not args.name or not args.output:
|
||||
parser.error("--template, --name, and --output are required unless --list is used")
|
||||
|
||||
output_path = Path(args.output).expanduser()
|
||||
if output_path.exists() and not args.force:
|
||||
raise SystemExit(f"Refusing to overwrite existing file: {output_path}")
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(render_template(args.template, args.name), encoding="utf-8")
|
||||
print(f"Wrote {output_path}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
|
||||
|
||||
mcp = FastMCP("__SERVER_NAME__")
|
||||
|
||||
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.example.com")
|
||||
API_TOKEN = os.getenv("API_TOKEN")
|
||||
REQUEST_TIMEOUT = float(os.getenv("API_TIMEOUT_SECONDS", "20"))
|
||||
|
||||
|
||||
def _headers() -> dict[str, str]:
|
||||
headers = {"Accept": "application/json"}
|
||||
if API_TOKEN:
|
||||
headers["Authorization"] = f"Bearer {API_TOKEN}"
|
||||
return headers
|
||||
|
||||
|
||||
def _request(method: str, path: str, *, params: dict[str, Any] | None = None) -> Any:
|
||||
url = f"{API_BASE_URL.rstrip('/')}/{path.lstrip('/')}"
|
||||
with httpx.Client(timeout=REQUEST_TIMEOUT, headers=_headers()) as client:
|
||||
response = client.request(method, url, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def health_check() -> dict[str, Any]:
|
||||
"""Check whether the upstream API is reachable."""
|
||||
payload = _request("GET", "/health")
|
||||
return {"base_url": API_BASE_URL, "result": payload}
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def get_resource(resource_id: str) -> dict[str, Any]:
|
||||
"""Fetch one resource by ID from the upstream API."""
|
||||
payload = _request("GET", f"/resources/{resource_id}")
|
||||
return {"resource_id": resource_id, "data": payload}
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def search_resources(query: str, limit: int = 10) -> dict[str, Any]:
|
||||
"""Search upstream resources by query string."""
|
||||
payload = _request("GET", "/resources", params={"q": query, "limit": limit})
|
||||
return {"query": query, "limit": limit, "results": payload}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import FastMCP
|
||||
|
||||
|
||||
mcp = FastMCP("__SERVER_NAME__")
|
||||
|
||||
DATABASE_PATH = os.getenv("SQLITE_PATH", "./app.db")
|
||||
MAX_ROWS = int(os.getenv("SQLITE_MAX_ROWS", "200"))
|
||||
TABLE_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
def _connect() -> sqlite3.Connection:
|
||||
return sqlite3.connect(f"file:{DATABASE_PATH}?mode=ro", uri=True)
|
||||
|
||||
|
||||
def _reject_mutation(sql: str) -> None:
|
||||
normalized = sql.strip().lower()
|
||||
if not normalized.startswith("select"):
|
||||
raise ValueError("Only SELECT queries are allowed")
|
||||
|
||||
|
||||
def _validate_table_name(table_name: str) -> str:
|
||||
if not TABLE_NAME_RE.fullmatch(table_name):
|
||||
raise ValueError("Invalid table name")
|
||||
return table_name
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def list_tables() -> list[str]:
|
||||
"""List user-defined SQLite tables."""
|
||||
with _connect() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
|
||||
).fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def describe_table(table_name: str) -> list[dict[str, Any]]:
|
||||
"""Describe columns for a SQLite table."""
|
||||
safe_table_name = _validate_table_name(table_name)
|
||||
with _connect() as conn:
|
||||
rows = conn.execute(f"PRAGMA table_info({safe_table_name})").fetchall()
|
||||
return [
|
||||
{
|
||||
"cid": row[0],
|
||||
"name": row[1],
|
||||
"type": row[2],
|
||||
"notnull": bool(row[3]),
|
||||
"default": row[4],
|
||||
"pk": bool(row[5]),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def query(sql: str, limit: int = 50) -> dict[str, Any]:
|
||||
"""Run a read-only SELECT query and return rows plus column names."""
|
||||
_reject_mutation(sql)
|
||||
safe_limit = max(0, min(limit, MAX_ROWS))
|
||||
wrapped_sql = f"SELECT * FROM ({sql.strip().rstrip(';')}) LIMIT {safe_limit}"
|
||||
with _connect() as conn:
|
||||
cursor = conn.execute(wrapped_sql)
|
||||
columns = [column[0] for column in cursor.description or []]
|
||||
rows = [dict(zip(columns, row)) for row in cursor.fetchall()]
|
||||
return {"limit": safe_limit, "columns": columns, "rows": rows}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import FastMCP
|
||||
|
||||
|
||||
mcp = FastMCP("__SERVER_NAME__")
|
||||
|
||||
|
||||
def _read_text(path: str) -> str:
|
||||
file_path = Path(path).expanduser()
|
||||
try:
|
||||
return file_path.read_text(encoding="utf-8")
|
||||
except FileNotFoundError as exc:
|
||||
raise ValueError(f"File not found: {file_path}") from exc
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ValueError(f"File is not valid UTF-8 text: {file_path}") from exc
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def summarize_text_file(path: str, preview_chars: int = 1200) -> dict[str, int | str]:
|
||||
"""Return basic metadata and a preview for a UTF-8 text file."""
|
||||
file_path = Path(path).expanduser()
|
||||
text = _read_text(path)
|
||||
return {
|
||||
"path": str(file_path),
|
||||
"characters": len(text),
|
||||
"lines": len(text.splitlines()),
|
||||
"preview": text[:preview_chars],
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def search_text_file(path: str, needle: str, max_matches: int = 20) -> dict[str, Any]:
|
||||
"""Find matching lines in a UTF-8 text file."""
|
||||
file_path = Path(path).expanduser()
|
||||
matches: list[dict[str, Any]] = []
|
||||
for line_number, line in enumerate(_read_text(path).splitlines(), start=1):
|
||||
if needle.lower() in line.lower():
|
||||
matches.append({"line_number": line_number, "line": line})
|
||||
if len(matches) >= max_matches:
|
||||
break
|
||||
return {"path": str(file_path), "needle": needle, "matches": matches}
|
||||
|
||||
|
||||
@mcp.resource("file://{path}")
|
||||
def read_file_resource(path: str) -> str:
|
||||
"""Expose a text file as a resource."""
|
||||
return _read_text(path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
@@ -0,0 +1,235 @@
|
||||
---
|
||||
name: bioinformatics
|
||||
description: Gateway to 400+ bioinformatics skills from bioSkills and ClawBio. Covers genomics, transcriptomics, single-cell, variant calling, pharmacogenomics, metagenomics, structural biology, and more. Fetches domain-specific reference material on demand.
|
||||
version: 1.0.0
|
||||
platforms: [linux, macos]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [bioinformatics, genomics, sequencing, biology, research, science]
|
||||
category: research
|
||||
---
|
||||
|
||||
# Bioinformatics Skills Gateway
|
||||
|
||||
Use when asked about bioinformatics, genomics, sequencing, variant calling, gene expression, single-cell analysis, protein structure, pharmacogenomics, metagenomics, phylogenetics, or any computational biology task.
|
||||
|
||||
This skill is a gateway to two open-source bioinformatics skill libraries. Instead of bundling hundreds of domain-specific skills, it indexes them and fetches what you need on demand.
|
||||
|
||||
## Sources
|
||||
|
||||
◆ **bioSkills** — 385 reference skills (code patterns, parameter guides, decision trees)
|
||||
Repo: https://github.com/GPTomics/bioSkills
|
||||
Format: SKILL.md per topic with code examples. Python/R/CLI.
|
||||
|
||||
◆ **ClawBio** — 33 runnable pipeline skills (executable scripts, reproducibility bundles)
|
||||
Repo: https://github.com/ClawBio/ClawBio
|
||||
Format: Python scripts with demos. Each analysis exports report.md + commands.sh + environment.yml.
|
||||
|
||||
## How to fetch and use a skill
|
||||
|
||||
1. Identify the domain and skill name from the index below.
|
||||
2. Clone the relevant repo (shallow clone to save time):
|
||||
```bash
|
||||
# bioSkills (reference material)
|
||||
git clone --depth 1 https://github.com/GPTomics/bioSkills.git /tmp/bioSkills
|
||||
|
||||
# ClawBio (runnable pipelines)
|
||||
git clone --depth 1 https://github.com/ClawBio/ClawBio.git /tmp/ClawBio
|
||||
```
|
||||
3. Read the specific skill:
|
||||
```bash
|
||||
# bioSkills — each skill is at: <category>/<skill-name>/SKILL.md
|
||||
cat /tmp/bioSkills/variant-calling/gatk-variant-calling/SKILL.md
|
||||
|
||||
# ClawBio — each skill is at: skills/<skill-name>/
|
||||
cat /tmp/ClawBio/skills/pharmgx-reporter/README.md
|
||||
```
|
||||
4. Follow the fetched skill as reference material. These are NOT Hermes-format skills — treat them as expert domain guides. They contain correct parameters, proper tool flags, and validated pipelines.
|
||||
|
||||
## Skill Index by Domain
|
||||
|
||||
### Sequence Fundamentals
|
||||
bioSkills:
|
||||
sequence-io/ — read-sequences, write-sequences, format-conversion, batch-processing, compressed-files, fastq-quality, filter-sequences, paired-end-fastq, sequence-statistics
|
||||
sequence-manipulation/ — seq-objects, reverse-complement, transcription-translation, motif-search, codon-usage, sequence-properties, sequence-slicing
|
||||
ClawBio:
|
||||
seq-wrangler — Sequence QC, alignment, and BAM processing (wraps FastQC, BWA, SAMtools)
|
||||
|
||||
### Read QC & Alignment
|
||||
bioSkills:
|
||||
read-qc/ — quality-reports, fastp-workflow, adapter-trimming, quality-filtering, umi-processing, contamination-screening, rnaseq-qc
|
||||
read-alignment/ — bwa-alignment, star-alignment, hisat2-alignment, bowtie2-alignment
|
||||
alignment-files/ — sam-bam-basics, alignment-sorting, alignment-filtering, bam-statistics, duplicate-handling, pileup-generation
|
||||
|
||||
### Variant Calling & Annotation
|
||||
bioSkills:
|
||||
variant-calling/ — gatk-variant-calling, deepvariant, variant-calling (bcftools), joint-calling, structural-variant-calling, filtering-best-practices, variant-annotation, variant-normalization, vcf-basics, vcf-manipulation, vcf-statistics, consensus-sequences, clinical-interpretation
|
||||
ClawBio:
|
||||
vcf-annotator — VEP + ClinVar + gnomAD annotation with ancestry-aware context
|
||||
variant-annotation — Variant annotation pipeline
|
||||
|
||||
### Differential Expression (Bulk RNA-seq)
|
||||
bioSkills:
|
||||
differential-expression/ — deseq2-basics, edger-basics, batch-correction, de-results, de-visualization, timeseries-de
|
||||
rna-quantification/ — alignment-free-quant (Salmon/kallisto), featurecounts-counting, tximport-workflow, count-matrix-qc
|
||||
expression-matrix/ — counts-ingest, gene-id-mapping, metadata-joins, sparse-handling
|
||||
ClawBio:
|
||||
rnaseq-de — Full DE pipeline with QC, normalization, and visualization
|
||||
diff-visualizer — Rich visualization and reporting for DE results
|
||||
|
||||
### Single-Cell RNA-seq
|
||||
bioSkills:
|
||||
single-cell/ — preprocessing, clustering, batch-integration, cell-annotation, cell-communication, doublet-detection, markers-annotation, trajectory-inference, multimodal-integration, perturb-seq, scatac-analysis, lineage-tracing, metabolite-communication, data-io
|
||||
ClawBio:
|
||||
scrna-orchestrator — Full Scanpy pipeline (QC, clustering, markers, annotation)
|
||||
scrna-embedding — scVI-based latent embedding and batch integration
|
||||
|
||||
### Spatial Transcriptomics
|
||||
bioSkills:
|
||||
spatial-transcriptomics/ — spatial-data-io, spatial-preprocessing, spatial-domains, spatial-deconvolution, spatial-communication, spatial-neighbors, spatial-statistics, spatial-visualization, spatial-multiomics, spatial-proteomics, image-analysis
|
||||
|
||||
### Epigenomics
|
||||
bioSkills:
|
||||
chip-seq/ — peak-calling, differential-binding, motif-analysis, peak-annotation, chipseq-qc, chipseq-visualization, super-enhancers
|
||||
atac-seq/ — atac-peak-calling, atac-qc, differential-accessibility, footprinting, motif-deviation, nucleosome-positioning
|
||||
methylation-analysis/ — bismark-alignment, methylation-calling, dmr-detection, methylkit-analysis
|
||||
hi-c-analysis/ — hic-data-io, tad-detection, loop-calling, compartment-analysis, contact-pairs, matrix-operations, hic-visualization, hic-differential
|
||||
ClawBio:
|
||||
methylation-clock — Epigenetic age estimation
|
||||
|
||||
### Pharmacogenomics & Clinical
|
||||
bioSkills:
|
||||
clinical-databases/ — clinvar-lookup, gnomad-frequencies, dbsnp-queries, pharmacogenomics, polygenic-risk, hla-typing, variant-prioritization, somatic-signatures, tumor-mutational-burden, myvariant-queries
|
||||
ClawBio:
|
||||
pharmgx-reporter — PGx report from 23andMe/AncestryDNA (12 genes, 31 SNPs, 51 drugs)
|
||||
drug-photo — Photo of medication → personalized PGx dosage card (via vision)
|
||||
clinpgx — ClinPGx API for gene-drug data and CPIC guidelines
|
||||
gwas-lookup — Federated variant lookup across 9 genomic databases
|
||||
gwas-prs — Polygenic risk scores from consumer genetic data
|
||||
nutrigx_advisor — Personalized nutrition from consumer genetic data
|
||||
|
||||
### Population Genetics & GWAS
|
||||
bioSkills:
|
||||
population-genetics/ — association-testing (PLINK GWAS), plink-basics, population-structure, linkage-disequilibrium, scikit-allel-analysis, selection-statistics
|
||||
causal-genomics/ — mendelian-randomization, fine-mapping, colocalization-analysis, mediation-analysis, pleiotropy-detection
|
||||
phasing-imputation/ — haplotype-phasing, genotype-imputation, imputation-qc, reference-panels
|
||||
ClawBio:
|
||||
claw-ancestry-pca — Ancestry PCA against SGDP reference panel
|
||||
|
||||
### Metagenomics & Microbiome
|
||||
bioSkills:
|
||||
metagenomics/ — kraken-classification, metaphlan-profiling, abundance-estimation, functional-profiling, amr-detection, strain-tracking, metagenome-visualization
|
||||
microbiome/ — amplicon-processing, diversity-analysis, differential-abundance, taxonomy-assignment, functional-prediction, qiime2-workflow
|
||||
ClawBio:
|
||||
claw-metagenomics — Shotgun metagenomics profiling (taxonomy, resistome, functional pathways)
|
||||
|
||||
### Genome Assembly & Annotation
|
||||
bioSkills:
|
||||
genome-assembly/ — hifi-assembly, long-read-assembly, short-read-assembly, metagenome-assembly, assembly-polishing, assembly-qc, scaffolding, contamination-detection
|
||||
genome-annotation/ — eukaryotic-gene-prediction, prokaryotic-annotation, functional-annotation, ncrna-annotation, repeat-annotation, annotation-transfer
|
||||
long-read-sequencing/ — basecalling, long-read-alignment, long-read-qc, clair3-variants, structural-variants, medaka-polishing, nanopore-methylation, isoseq-analysis
|
||||
|
||||
### Structural Biology & Chemoinformatics
|
||||
bioSkills:
|
||||
structural-biology/ — alphafold-predictions, modern-structure-prediction, structure-io, structure-navigation, structure-modification, geometric-analysis
|
||||
chemoinformatics/ — molecular-io, molecular-descriptors, similarity-searching, substructure-search, virtual-screening, admet-prediction, reaction-enumeration
|
||||
ClawBio:
|
||||
struct-predictor — Local AlphaFold/Boltz/Chai structure prediction with comparison
|
||||
|
||||
### Proteomics
|
||||
bioSkills:
|
||||
proteomics/ — data-import, peptide-identification, protein-inference, quantification, differential-abundance, dia-analysis, ptm-analysis, proteomics-qc, spectral-libraries
|
||||
ClawBio:
|
||||
proteomics-de — Proteomics differential expression
|
||||
|
||||
### Pathway Analysis & Gene Networks
|
||||
bioSkills:
|
||||
pathway-analysis/ — go-enrichment, gsea, kegg-pathways, reactome-pathways, wikipathways, enrichment-visualization
|
||||
gene-regulatory-networks/ — scenic-regulons, coexpression-networks, differential-networks, multiomics-grn, perturbation-simulation
|
||||
|
||||
### Immunoinformatics
|
||||
bioSkills:
|
||||
immunoinformatics/ — mhc-binding-prediction, epitope-prediction, neoantigen-prediction, immunogenicity-scoring, tcr-epitope-binding
|
||||
tcr-bcr-analysis/ — mixcr-analysis, scirpy-analysis, immcantation-analysis, repertoire-visualization, vdjtools-analysis
|
||||
|
||||
### CRISPR & Genome Engineering
|
||||
bioSkills:
|
||||
crispr-screens/ — mageck-analysis, jacks-analysis, hit-calling, screen-qc, library-design, crispresso-editing, base-editing-analysis, batch-correction
|
||||
genome-engineering/ — grna-design, off-target-prediction, hdr-template-design, base-editing-design, prime-editing-design
|
||||
|
||||
### Workflow Management
|
||||
bioSkills:
|
||||
workflow-management/ — snakemake-workflows, nextflow-pipelines, cwl-workflows, wdl-workflows
|
||||
ClawBio:
|
||||
repro-enforcer — Export any analysis as reproducibility bundle (Conda env + Singularity + checksums)
|
||||
galaxy-bridge — Access 8,000+ Galaxy tools from usegalaxy.org
|
||||
|
||||
### Specialized Domains
|
||||
bioSkills:
|
||||
alternative-splicing/ — splicing-quantification, differential-splicing, isoform-switching, sashimi-plots, single-cell-splicing, splicing-qc
|
||||
ecological-genomics/ — edna-metabarcoding, landscape-genomics, conservation-genetics, biodiversity-metrics, community-ecology, species-delimitation
|
||||
epidemiological-genomics/ — pathogen-typing, variant-surveillance, phylodynamics, transmission-inference, amr-surveillance
|
||||
liquid-biopsy/ — cfdna-preprocessing, ctdna-mutation-detection, fragment-analysis, tumor-fraction-estimation, methylation-based-detection, longitudinal-monitoring
|
||||
epitranscriptomics/ — m6a-peak-calling, m6a-differential, m6anet-analysis, merip-preprocessing, modification-visualization
|
||||
metabolomics/ — xcms-preprocessing, metabolite-annotation, normalization-qc, statistical-analysis, pathway-mapping, lipidomics, targeted-analysis, msdial-preprocessing
|
||||
flow-cytometry/ — fcs-handling, gating-analysis, compensation-transformation, clustering-phenotyping, differential-analysis, cytometry-qc, doublet-detection, bead-normalization
|
||||
systems-biology/ — flux-balance-analysis, metabolic-reconstruction, gene-essentiality, context-specific-models, model-curation
|
||||
rna-structure/ — secondary-structure-prediction, ncrna-search, structure-probing
|
||||
|
||||
### Data Visualization & Reporting
|
||||
bioSkills:
|
||||
data-visualization/ — ggplot2-fundamentals, heatmaps-clustering, volcano-customization, circos-plots, genome-browser-tracks, interactive-visualization, multipanel-figures, network-visualization, upset-plots, color-palettes, specialized-omics-plots, genome-tracks
|
||||
reporting/ — rmarkdown-reports, quarto-reports, jupyter-reports, automated-qc-reports, figure-export
|
||||
ClawBio:
|
||||
profile-report — Analysis profile reporting
|
||||
data-extractor — Extract numerical data from scientific figure images (via vision)
|
||||
lit-synthesizer — PubMed/bioRxiv search, summarization, citation graphs
|
||||
pubmed-summariser — Gene/disease PubMed search with structured briefing
|
||||
|
||||
### Database Access
|
||||
bioSkills:
|
||||
database-access/ — entrez-search, entrez-fetch, entrez-link, blast-searches, local-blast, sra-data, geo-data, uniprot-access, batch-downloads, interaction-databases, sequence-similarity
|
||||
ClawBio:
|
||||
ukb-navigator — Semantic search across 12,000+ UK Biobank fields
|
||||
clinical-trial-finder — Clinical trial discovery
|
||||
|
||||
### Experimental Design
|
||||
bioSkills:
|
||||
experimental-design/ — power-analysis, sample-size, batch-design, multiple-testing
|
||||
|
||||
### Machine Learning for Omics
|
||||
bioSkills:
|
||||
machine-learning/ — omics-classifiers, biomarker-discovery, survival-analysis, model-validation, prediction-explanation, atlas-mapping
|
||||
ClawBio:
|
||||
claw-semantic-sim — Semantic similarity index for disease literature (PubMedBERT)
|
||||
omics-target-evidence-mapper — Aggregate target-level evidence across omics sources
|
||||
|
||||
## Environment Setup
|
||||
|
||||
These skills assume a bioinformatics workstation. Common dependencies:
|
||||
|
||||
```bash
|
||||
# Python
|
||||
pip install biopython pysam cyvcf2 pybedtools pyBigWig scikit-allel anndata scanpy mygene
|
||||
|
||||
# R/Bioconductor
|
||||
Rscript -e 'BiocManager::install(c("DESeq2","edgeR","Seurat","clusterProfiler","methylKit"))'
|
||||
|
||||
# CLI tools (Ubuntu/Debian)
|
||||
sudo apt install samtools bcftools ncbi-blast+ minimap2 bedtools
|
||||
|
||||
# CLI tools (macOS)
|
||||
brew install samtools bcftools blast minimap2 bedtools
|
||||
|
||||
# Or via Conda (recommended for reproducibility)
|
||||
conda install -c bioconda samtools bcftools blast minimap2 bedtools fastp kraken2
|
||||
```
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- The fetched skills are NOT in Hermes SKILL.md format. They use their own structure (bioSkills: code pattern cookbooks; ClawBio: README + Python scripts). Read them as expert reference material.
|
||||
- bioSkills are reference guides — they show correct parameters and code patterns but aren't executable pipelines.
|
||||
- ClawBio skills are executable — many have `--demo` flags and can be run directly.
|
||||
- Both repos assume bioinformatics tools are installed. Check prerequisites before running pipelines.
|
||||
- For ClawBio, run `pip install -r requirements.txt` in the cloned repo first.
|
||||
- Genomic data files can be very large. Be mindful of disk space when downloading reference genomes, SRA datasets, or building indices.
|
||||
@@ -0,0 +1,422 @@
|
||||
---
|
||||
name: oss-forensics
|
||||
description: |
|
||||
Supply chain investigation, evidence recovery, and forensic analysis for GitHub repositories.
|
||||
Covers deleted commit recovery, force-push detection, IOC extraction, multi-source evidence
|
||||
collection, hypothesis formation/validation, and structured forensic reporting.
|
||||
Inspired by RAPTOR's 1800+ line OSS Forensics system.
|
||||
category: security
|
||||
triggers:
|
||||
- "investigate this repository"
|
||||
- "investigate [owner/repo]"
|
||||
- "check for supply chain compromise"
|
||||
- "recover deleted commits"
|
||||
- "forensic analysis of [owner/repo]"
|
||||
- "was this repo compromised"
|
||||
- "supply chain attack"
|
||||
- "suspicious commit"
|
||||
- "force push detected"
|
||||
- "IOC extraction"
|
||||
toolsets:
|
||||
- terminal
|
||||
- web
|
||||
- file
|
||||
- delegation
|
||||
---
|
||||
|
||||
# OSS Security Forensics Skill
|
||||
|
||||
A 7-phase multi-agent investigation framework for researching open-source supply chain attacks.
|
||||
Adapted from RAPTOR's forensics system. Covers GitHub Archive, Wayback Machine, GitHub API,
|
||||
local git analysis, IOC extraction, evidence-backed hypothesis formation and validation,
|
||||
and final forensic report generation.
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ Anti-Hallucination Guardrails
|
||||
|
||||
Read these before every investigation step. Violating them invalidates the report.
|
||||
|
||||
1. **Evidence-First Rule**: Every claim in any report, hypothesis, or summary MUST cite at least one evidence ID (`EV-XXXX`). Assertions without citations are forbidden.
|
||||
2. **STAY IN YOUR LANE**: Each sub-agent (investigator) has a single data source. Do NOT mix sources. The GH Archive investigator does not query the GitHub API, and vice versa. Role boundaries are hard.
|
||||
3. **Fact vs. Hypothesis Separation**: Mark all unverified inferences with `[HYPOTHESIS]`. Only statements verified against original sources may be stated as facts.
|
||||
4. **No Evidence Fabrication**: The hypothesis validator MUST mechanically check that every cited evidence ID actually exists in the evidence store before accepting a hypothesis.
|
||||
5. **Proof-Required Disproval**: A hypothesis cannot be dismissed without a specific, evidence-backed counter-argument. "No evidence found" is not sufficient to disprove—it only makes a hypothesis inconclusive.
|
||||
6. **SHA/URL Double-Verification**: Any commit SHA, URL, or external identifier cited as evidence must be independently confirmed from at least two sources before being marked as verified.
|
||||
7. **Suspicious Code Rule**: Never run code found inside the investigated repository locally. Analyze statically only, or use `execute_code` in a sandboxed environment.
|
||||
8. **Secret Redaction**: Any API keys, tokens, or credentials discovered during investigation must be redacted in the final report. Log them internally only.
|
||||
|
||||
---
|
||||
|
||||
## Example Scenarios
|
||||
|
||||
- **Scenario A: Dependency Confusion**: A malicious package `internal-lib-v2` is uploaded to NPM with a higher version than the internal one. The investigator must track when this package was first seen and if any PushEvents in the target repo updated `package.json` to this version.
|
||||
- **Scenario B: Maintainer Takeover**: A long-term contributor's account is used to push a backdoored `.github/workflows/build.yml`. The investigator looks for PushEvents from this user after a long period of inactivity or from a new IP/location (if detectable via BigQuery).
|
||||
- **Scenario C: Force-Push Hide**: A developer accidentally commits a production secret, then force-pushes to "fix" it. The investigator uses `git fsck` and GH Archive to recover the original commit SHA and verify what was leaked.
|
||||
|
||||
---
|
||||
|
||||
> **Path convention**: Throughout this skill, `SKILL_DIR` refers to the root of this skill's
|
||||
> installation directory (the folder containing this `SKILL.md`). When the skill is loaded,
|
||||
> resolve `SKILL_DIR` to the actual path — e.g. `~/.hermes/skills/security/oss-forensics/`
|
||||
> or the `optional-skills/` equivalent. All script and template references are relative to it.
|
||||
|
||||
## Phase 0: Initialization
|
||||
|
||||
1. Create investigation working directory:
|
||||
```bash
|
||||
mkdir investigation_$(echo "REPO_NAME" | tr '/' '_')
|
||||
cd investigation_$(echo "REPO_NAME" | tr '/' '_')
|
||||
```
|
||||
2. Initialize the evidence store:
|
||||
```bash
|
||||
python3 SKILL_DIR/scripts/evidence-store.py --store evidence.json list
|
||||
```
|
||||
3. Copy the forensic report template:
|
||||
```bash
|
||||
cp SKILL_DIR/templates/forensic-report.md ./investigation-report.md
|
||||
```
|
||||
4. Create an `iocs.md` file to track Indicators of Compromise as they are discovered.
|
||||
5. Record the investigation start time, target repository, and stated investigation goal.
|
||||
|
||||
---
|
||||
|
||||
## Phase 1: Prompt Parsing and IOC Extraction
|
||||
|
||||
**Goal**: Extract all structured investigative targets from the user's request.
|
||||
|
||||
**Actions**:
|
||||
- Parse the user prompt and extract:
|
||||
- Target repository (`owner/repo`)
|
||||
- Target actors (GitHub handles, email addresses)
|
||||
- Time window of interest (commit date ranges, PR timestamps)
|
||||
- Provided Indicators of Compromise: commit SHAs, file paths, package names, IP addresses, domains, API keys/tokens, malicious URLs
|
||||
- Any linked vendor security reports or blog posts
|
||||
|
||||
**Tools**: Reasoning only, or `execute_code` for regex extraction from large text blocks.
|
||||
|
||||
**Output**: Populate `iocs.md` with extracted IOCs. Each IOC must have:
|
||||
- Type (from: COMMIT_SHA, FILE_PATH, API_KEY, SECRET, IP_ADDRESS, DOMAIN, PACKAGE_NAME, ACTOR_USERNAME, MALICIOUS_URL, OTHER)
|
||||
- Value
|
||||
- Source (user-provided, inferred)
|
||||
|
||||
**Reference**: See [evidence-types.md](./references/evidence-types.md) for IOC taxonomy.
|
||||
|
||||
---
|
||||
|
||||
## Phase 2: Parallel Evidence Collection
|
||||
|
||||
Spawn up to 5 specialist investigator sub-agents using `delegate_task` (batch mode, max 3 concurrent). Each investigator has a **single data source** and must not mix sources.
|
||||
|
||||
> **Orchestrator note**: Pass the IOC list from Phase 1 and the investigation time window in the `context` field of each delegated task.
|
||||
|
||||
---
|
||||
|
||||
### Investigator 1: Local Git Investigator
|
||||
|
||||
**ROLE BOUNDARY**: You query the LOCAL GIT REPOSITORY ONLY. Do not call any external APIs.
|
||||
|
||||
**Actions**:
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/OWNER/REPO.git target_repo && cd target_repo
|
||||
|
||||
# Full commit log with stats
|
||||
git log --all --full-history --stat --format="%H|%ae|%an|%ai|%s" > ../git_log.txt
|
||||
|
||||
# Detect force-push evidence (orphaned/dangling commits)
|
||||
git fsck --lost-found --unreachable 2>&1 | grep commit > ../dangling_commits.txt
|
||||
|
||||
# Check reflog for rewritten history
|
||||
git reflog --all > ../reflog.txt
|
||||
|
||||
# List ALL branches including deleted remote refs
|
||||
git branch -a -v > ../branches.txt
|
||||
|
||||
# Find suspicious large binary additions
|
||||
git log --all --diff-filter=A --name-only --format="%H %ai" -- "*.so" "*.dll" "*.exe" "*.bin" > ../binary_additions.txt
|
||||
|
||||
# Check for GPG signature anomalies
|
||||
git log --show-signature --format="%H %ai %aN" > ../signature_check.txt 2>&1
|
||||
```
|
||||
|
||||
**Evidence to collect** (add via `python3 SKILL_DIR/scripts/evidence-store.py add`):
|
||||
- Each dangling commit SHA → type: `git`
|
||||
- Force-push evidence (reflog showing history rewrite) → type: `git`
|
||||
- Unsigned commits from verified contributors → type: `git`
|
||||
- Suspicious binary file additions → type: `git`
|
||||
|
||||
**Reference**: See [recovery-techniques.md](./references/recovery-techniques.md) for accessing force-pushed commits.
|
||||
|
||||
---
|
||||
|
||||
### Investigator 2: GitHub API Investigator
|
||||
|
||||
**ROLE BOUNDARY**: You query the GITHUB REST API ONLY. Do not run git commands locally.
|
||||
|
||||
**Actions**:
|
||||
```bash
|
||||
# Commits (paginated)
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/commits?per_page=100" > api_commits.json
|
||||
|
||||
# Pull Requests including closed/deleted
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/pulls?state=all&per_page=100" > api_prs.json
|
||||
|
||||
# Issues
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/issues?state=all&per_page=100" > api_issues.json
|
||||
|
||||
# Contributors and collaborator changes
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/contributors" > api_contributors.json
|
||||
|
||||
# Repository events (last 300)
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/events?per_page=100" > api_events.json
|
||||
|
||||
# Check specific suspicious commit SHA details
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/git/commits/SHA" > commit_detail.json
|
||||
|
||||
# Releases
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/releases?per_page=100" > api_releases.json
|
||||
|
||||
# Check if a specific commit exists (force-pushed commits may 404 on commits/ but succeed on git/commits/)
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/commits/SHA" | jq .sha
|
||||
```
|
||||
|
||||
**Cross-reference targets** (flag discrepancies as evidence):
|
||||
- PR exists in archive but missing from API → evidence of deletion
|
||||
- Contributor in archive events but not in contributors list → evidence of permission revocation
|
||||
- Commit in archive PushEvents but not in API commit list → evidence of force-push/deletion
|
||||
|
||||
**Reference**: See [evidence-types.md](./references/evidence-types.md) for GH event types.
|
||||
|
||||
---
|
||||
|
||||
### Investigator 3: Wayback Machine Investigator
|
||||
|
||||
**ROLE BOUNDARY**: You query the WAYBACK MACHINE CDX API ONLY. Do not use the GitHub API.
|
||||
|
||||
**Goal**: Recover deleted GitHub pages (READMEs, issues, PRs, releases, wiki pages).
|
||||
|
||||
**Actions**:
|
||||
```bash
|
||||
# Search for archived snapshots of the repo main page
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO&output=json&limit=100&from=YYYYMMDD&to=YYYYMMDD" > wayback_main.json
|
||||
|
||||
# Search for a specific deleted issue
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO/issues/NUM&output=json&limit=50" > wayback_issue_NUM.json
|
||||
|
||||
# Search for a specific deleted PR
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO/pull/NUM&output=json&limit=50" > wayback_pr_NUM.json
|
||||
|
||||
# Fetch the best snapshot of a page
|
||||
# Use the Wayback Machine URL: https://web.archive.org/web/TIMESTAMP/ORIGINAL_URL
|
||||
# Example: https://web.archive.org/web/20240101000000*/github.com/OWNER/REPO
|
||||
|
||||
# Advanced: Search for deleted releases/tags
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO/releases/tag/*&output=json" > wayback_tags.json
|
||||
|
||||
# Advanced: Search for historical wiki changes
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO/wiki/*&output=json" > wayback_wiki.json
|
||||
```
|
||||
|
||||
**Evidence to collect**:
|
||||
- Archived snapshots of deleted issues/PRs with their content
|
||||
- Historical README versions showing changes
|
||||
- Evidence of content present in archive but missing from current GitHub state
|
||||
|
||||
**Reference**: See [github-archive-guide.md](./references/github-archive-guide.md) for CDX API parameters.
|
||||
|
||||
---
|
||||
|
||||
### Investigator 4: GH Archive / BigQuery Investigator
|
||||
|
||||
**ROLE BOUNDARY**: You query GITHUB ARCHIVE via BIGQUERY ONLY. This is a tamper-proof record of all public GitHub events.
|
||||
|
||||
> **Prerequisites**: Requires Google Cloud credentials with BigQuery access (`gcloud auth application-default login`). If unavailable, skip this investigator and note it in the report.
|
||||
|
||||
**Cost Optimization Rules** (MANDATORY):
|
||||
1. ALWAYS run a `--dry_run` before every query to estimate cost.
|
||||
2. Use `_TABLE_SUFFIX` to filter by date range and minimize scanned data.
|
||||
3. Only SELECT the columns you need.
|
||||
4. Add a LIMIT unless aggregating.
|
||||
|
||||
```bash
|
||||
# Template: safe BigQuery query for PushEvents to OWNER/REPO
|
||||
bq query --use_legacy_sql=false --dry_run "
|
||||
SELECT created_at, actor.login, payload.commits, payload.before, payload.head,
|
||||
payload.size, payload.distinct_size
|
||||
FROM \`githubarchive.month.*\`
|
||||
WHERE _TABLE_SUFFIX BETWEEN 'YYYYMM' AND 'YYYYMM'
|
||||
AND type = 'PushEvent'
|
||||
AND repo.name = 'OWNER/REPO'
|
||||
LIMIT 1000
|
||||
"
|
||||
# If cost is acceptable, re-run without --dry_run
|
||||
|
||||
# Detect force-pushes: zero-distinct_size PushEvents mean commits were force-erased
|
||||
# payload.distinct_size = 0 AND payload.size > 0 → force push indicator
|
||||
|
||||
# Check for deleted branch events
|
||||
bq query --use_legacy_sql=false "
|
||||
SELECT created_at, actor.login, payload.ref, payload.ref_type
|
||||
FROM \`githubarchive.month.*\`
|
||||
WHERE _TABLE_SUFFIX BETWEEN 'YYYYMM' AND 'YYYYMM'
|
||||
AND type = 'DeleteEvent'
|
||||
AND repo.name = 'OWNER/REPO'
|
||||
LIMIT 200
|
||||
"
|
||||
```
|
||||
|
||||
**Evidence to collect**:
|
||||
- Force-push events (payload.size > 0, payload.distinct_size = 0)
|
||||
- DeleteEvents for branches/tags
|
||||
- WorkflowRunEvents for suspicious CI/CD automation
|
||||
- PushEvents that precede a "gap" in the git log (evidence of rewrite)
|
||||
|
||||
**Reference**: See [github-archive-guide.md](./references/github-archive-guide.md) for all 12 event types and query patterns.
|
||||
|
||||
---
|
||||
|
||||
### Investigator 5: IOC Enrichment Investigator
|
||||
|
||||
**ROLE BOUNDARY**: You enrich EXISTING IOCs from Phase 1 using passive public sources ONLY. Do not execute any code from the target repository.
|
||||
|
||||
**Actions**:
|
||||
- For each commit SHA: attempt recovery via direct GitHub URL (`github.com/OWNER/REPO/commit/SHA.patch`)
|
||||
- For each domain/IP: check passive DNS, WHOIS records (via `web_extract` on public WHOIS services)
|
||||
- For each package name: check npm/PyPI for matching malicious package reports
|
||||
- For each actor username: check GitHub profile, contribution history, account age
|
||||
- Recover force-pushed commits using 3 methods (see [recovery-techniques.md](./references/recovery-techniques.md))
|
||||
|
||||
---
|
||||
|
||||
## Phase 3: Evidence Consolidation
|
||||
|
||||
After all investigators complete:
|
||||
|
||||
1. Run `python3 SKILL_DIR/scripts/evidence-store.py --store evidence.json list` to see all collected evidence.
|
||||
2. For each piece of evidence, verify the `content_sha256` hash matches the original source.
|
||||
3. Group evidence by:
|
||||
- **Timeline**: Sort all timestamped evidence chronologically
|
||||
- **Actor**: Group by GitHub handle or email
|
||||
- **IOC**: Link evidence to the IOC it relates to
|
||||
4. Identify **discrepancies**: items present in one source but absent in another (key deletion indicators).
|
||||
5. Flag evidence as `[VERIFIED]` (confirmed from 2+ independent sources) or `[UNVERIFIED]` (single source only).
|
||||
|
||||
---
|
||||
|
||||
## Phase 4: Hypothesis Formation
|
||||
|
||||
A hypothesis must:
|
||||
- State a specific claim (e.g., "Actor X force-pushed to BRANCH on DATE to erase commit SHA")
|
||||
- Cite at least 2 evidence IDs that support it (`EV-XXXX`, `EV-YYYY`)
|
||||
- Identify what evidence would disprove it
|
||||
- Be labeled `[HYPOTHESIS]` until validated
|
||||
|
||||
**Common hypothesis templates** (see [investigation-templates.md](./references/investigation-templates.md)):
|
||||
- Maintainer Compromise: legitimate account used post-takeover to inject malicious code
|
||||
- Dependency Confusion: package name squatting to intercept installs
|
||||
- CI/CD Injection: malicious workflow changes to run code during builds
|
||||
- Typosquatting: near-identical package name targeting misspellers
|
||||
- Credential Leak: token/key accidentally committed then force-pushed to erase
|
||||
|
||||
For each hypothesis, spawn a `delegate_task` sub-agent to attempt to find disconfirming evidence before confirming.
|
||||
|
||||
---
|
||||
|
||||
## Phase 5: Hypothesis Validation
|
||||
|
||||
The validator sub-agent MUST mechanically check:
|
||||
|
||||
1. For each hypothesis, extract all cited evidence IDs.
|
||||
2. Verify each ID exists in `evidence.json` (hard failure if any ID is missing → hypothesis rejected as potentially fabricated).
|
||||
3. Verify each `[VERIFIED]` piece of evidence was confirmed from 2+ sources.
|
||||
4. Check logical consistency: does the timeline depicted by the evidence support the hypothesis?
|
||||
5. Check for alternative explanations: could the same evidence pattern arise from a benign cause?
|
||||
|
||||
**Output**:
|
||||
- `VALIDATED`: All evidence cited, verified, logically consistent, no plausible alternative explanation.
|
||||
- `INCONCLUSIVE`: Evidence supports hypothesis but alternative explanations exist or evidence is insufficient.
|
||||
- `REJECTED`: Missing evidence IDs, unverified evidence cited as fact, logical inconsistency detected.
|
||||
|
||||
Rejected hypotheses feed back into Phase 4 for refinement (max 3 iterations).
|
||||
|
||||
---
|
||||
|
||||
## Phase 6: Final Report Generation
|
||||
|
||||
Populate `investigation-report.md` using the template in [forensic-report.md](./templates/forensic-report.md).
|
||||
|
||||
**Mandatory sections**:
|
||||
- Executive Summary: one-paragraph verdict (Compromised / Clean / Inconclusive) with confidence level
|
||||
- Timeline: chronological reconstruction of all significant events with evidence citations
|
||||
- Validated Hypotheses: each with status and supporting evidence IDs
|
||||
- Evidence Registry: table of all `EV-XXXX` entries with source, type, and verification status
|
||||
- IOC List: all extracted and enriched Indicators of Compromise
|
||||
- Chain of Custody: how evidence was collected, from what sources, at what timestamps
|
||||
- Recommendations: immediate mitigations if compromise detected; monitoring recommendations
|
||||
|
||||
**Report rules**:
|
||||
- Every factual claim must have at least one `[EV-XXXX]` citation
|
||||
- Executive Summary must state confidence level (High / Medium / Low)
|
||||
- All secrets/credentials must be redacted to `[REDACTED]`
|
||||
|
||||
---
|
||||
|
||||
## Phase 7: Completion
|
||||
|
||||
1. Run final evidence count: `python3 SKILL_DIR/scripts/evidence-store.py --store evidence.json list`
|
||||
2. Archive the full investigation directory.
|
||||
3. If compromise is confirmed:
|
||||
- List immediate mitigations (rotate credentials, pin dependency hashes, notify affected users)
|
||||
- Identify affected versions/packages
|
||||
- Note disclosure obligations (if a public package: coordinate with the package registry)
|
||||
4. Present the final `investigation-report.md` to the user.
|
||||
|
||||
---
|
||||
|
||||
## Ethical Use Guidelines
|
||||
|
||||
This skill is designed for **defensive security investigation** — protecting open-source software from supply chain attacks. It must not be used for:
|
||||
|
||||
- **Harassment or stalking** of contributors or maintainers
|
||||
- **Doxing** — correlating GitHub activity to real identities for malicious purposes
|
||||
- **Competitive intelligence** — investigating proprietary or internal repositories without authorization
|
||||
- **False accusations** — publishing investigation results without validated evidence (see anti-hallucination guardrails)
|
||||
|
||||
Investigations should be conducted with the principle of **minimal intrusion**: collect only the evidence necessary to validate or refute the hypothesis. When publishing results, follow responsible disclosure practices and coordinate with affected maintainers before public disclosure.
|
||||
|
||||
If the investigation reveals a genuine compromise, follow the coordinated vulnerability disclosure process:
|
||||
1. Notify the repository maintainers privately first
|
||||
2. Allow reasonable time for remediation (typically 90 days)
|
||||
3. Coordinate with package registries (npm, PyPI, etc.) if published packages are affected
|
||||
4. File a CVE if appropriate
|
||||
|
||||
---
|
||||
|
||||
## API Rate Limiting
|
||||
|
||||
GitHub REST API enforces rate limits that will interrupt large investigations if not managed.
|
||||
|
||||
**Authenticated requests**: 5,000/hour (requires `GITHUB_TOKEN` env var or `gh` CLI auth)
|
||||
**Unauthenticated requests**: 60/hour (unusable for investigations)
|
||||
|
||||
**Best practices**:
|
||||
- Always authenticate: `export GITHUB_TOKEN=ghp_...` or use `gh` CLI (auto-authenticates)
|
||||
- Use conditional requests (`If-None-Match` / `If-Modified-Since` headers) to avoid consuming quota on unchanged data
|
||||
- For paginated endpoints, fetch all pages in sequence — don't parallelize against the same endpoint
|
||||
- Check `X-RateLimit-Remaining` header; if below 100, pause for `X-RateLimit-Reset` timestamp
|
||||
- BigQuery has its own quotas (10 TiB/day free tier) — always dry-run first
|
||||
- Wayback Machine CDX API: no formal rate limit, but be courteous (1-2 req/sec max)
|
||||
|
||||
If rate-limited mid-investigation, record the partial results in the evidence store and note the limitation in the report.
|
||||
|
||||
---
|
||||
|
||||
## Reference Materials
|
||||
|
||||
- [github-archive-guide.md](./references/github-archive-guide.md) — BigQuery queries, CDX API, 12 event types
|
||||
- [evidence-types.md](./references/evidence-types.md) — IOC taxonomy, evidence source types, observation types
|
||||
- [recovery-techniques.md](./references/recovery-techniques.md) — Recovering deleted commits, PRs, issues
|
||||
- [investigation-templates.md](./references/investigation-templates.md) — Pre-built hypothesis templates per attack type
|
||||
- [evidence-store.py](./scripts/evidence-store.py) — CLI tool for managing the evidence JSON store
|
||||
- [forensic-report.md](./templates/forensic-report.md) — Structured report template
|
||||
@@ -0,0 +1,89 @@
|
||||
# Evidence Types Reference
|
||||
|
||||
Taxonomy of all evidence types, IOC types, GitHub event types, and observation types
|
||||
used in OSS forensic investigations.
|
||||
|
||||
---
|
||||
|
||||
## Evidence Source Types
|
||||
|
||||
| Type | Description | Example Sources |
|
||||
|------|-------------|-----------------|
|
||||
| `git` | Data from local git repository analysis | `git log`, `git fsck`, `git reflog`, `git blame` |
|
||||
| `gh_api` | Data from GitHub REST API responses | `/repos/.../commits`, `/repos/.../pulls`, `/repos/.../events` |
|
||||
| `gh_archive` | Data from GitHub Archive (BigQuery) | `githubarchive.month.*` BigQuery tables |
|
||||
| `web_archive` | Archived web pages from Wayback Machine | CDX API results, `web.archive.org/web/...` snapshots |
|
||||
| `ioc` | Indicator of Compromise from any source | Extracted from vendor reports, git history, network traces |
|
||||
| `analysis` | Derived insight from cross-source correlation | "SHA present in archive but absent from API" |
|
||||
| `vendor_report` | External security vendor or researcher report | CVE advisories, blog posts, NVD records |
|
||||
| `manual` | Manually recorded observation by investigator | Notes on behavioral patterns, timeline gaps |
|
||||
|
||||
---
|
||||
|
||||
## IOC Types
|
||||
|
||||
| Type | Description | Example |
|
||||
|------|-------------|---------|
|
||||
| `COMMIT_SHA` | A git commit hash linked to malicious activity | `abc123def456...` |
|
||||
| `FILE_PATH` | A suspicious file inside the repository | `src/utils/crypto.js`, `dist/index.min.js` |
|
||||
| `API_KEY` | An API key accidentally committed | `AKIA...` (AWS), `ghp_...` (GitHub PAT) |
|
||||
| `SECRET` | A generic secret / credential | Database password, private key blob |
|
||||
| `IP_ADDRESS` | A C2 server or attacker IP | `192.0.2.1` |
|
||||
| `DOMAIN` | A malicious or suspicious domain | `evil-cdn.io`, typosquatted package registry domain |
|
||||
| `PACKAGE_NAME` | A malicious or squatted package name | `colo-rs` (typosquatting `color`), `lodash-utils` |
|
||||
| `ACTOR_USERNAME` | A GitHub handle linked to the attack | `malicious-bot-account` |
|
||||
| `MALICIOUS_URL` | A URL to a malicious resource | `https://evil.example.com/payload.sh` |
|
||||
| `WORKFLOW_FILE` | A suspicious CI/CD workflow file | `.github/workflows/release.yml` |
|
||||
| `BRANCH_NAME` | A suspicious branch | `refs/heads/temp-fix-do-not-merge` |
|
||||
| `TAG_NAME` | A suspicious git tag | `v1.0.0-security-patch` |
|
||||
| `RELEASE_NAME` | A suspicious release | Release with no associated tag or changelog |
|
||||
| `OTHER` | Catch-all for unclassified IOCs | — |
|
||||
|
||||
---
|
||||
|
||||
## GitHub Archive Event Types (12 Types)
|
||||
|
||||
| Event Type | Forensic Relevance |
|
||||
|------------|-------------------|
|
||||
| `PushEvent` | Core: `payload.distinct_size=0` with `payload.size>0` → force push. `payload.before`/`payload.head` shows rewritten history. |
|
||||
| `PullRequestEvent` | Detects deleted PRs, rapid open→close patterns, PRs from new accounts |
|
||||
| `IssueEvent` | Detects deleted issues, coordinated labeling, rapid closure of vulnerability reports |
|
||||
| `IssueCommentEvent` | Deleted comments, rapid activity bursts |
|
||||
| `WatchEvent` | Star-farming campaigns (coordinated starring from new accounts) |
|
||||
| `ForkEvent` | Unusual fork patterns before malicious commit |
|
||||
| `CreateEvent` | Branch/tag creation: signals new release or code injection point |
|
||||
| `DeleteEvent` | Branch/tag deletion: critical — often used to hide traces |
|
||||
| `ReleaseEvent` | Unauthorized releases, release artifacts modified post-publish |
|
||||
| `MemberEvent` | Collaborator added/removed: maintainer compromise indicator |
|
||||
| `PublicEvent` | Repository made public (sometimes to drop malicious code briefly) |
|
||||
| `WorkflowRunEvent` | CI/CD pipeline executions: workflow injection, secret exfiltration |
|
||||
|
||||
---
|
||||
|
||||
## Evidence Verification States
|
||||
|
||||
| State | Meaning |
|
||||
|-------|---------|
|
||||
| `unverified` | Collected from a single source, not cross-referenced |
|
||||
| `single_source` | The primary source has been confirmed directly (e.g., SHA resolves on GitHub), but no second source |
|
||||
| `multi_source_verified` | Confirmed from 2+ independent sources (e.g., GH Archive AND GitHub API both show the same event) |
|
||||
|
||||
Only `multi_source_verified` evidence may be cited as fact in validated hypotheses.
|
||||
`unverified` and `single_source` evidence must be labeled `[UNVERIFIED]` or `[SINGLE-SOURCE]`.
|
||||
|
||||
---
|
||||
|
||||
## Observation Types (Patterned after RAPTOR)
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `CommitObservation` | Specific commit SHA with metadata (author, date, files changed) |
|
||||
| `ForceWashObservation` | Evidence that commits were force-erased from a branch |
|
||||
| `DanglingCommitObservation` | SHA present in git object store but unreachable from any ref |
|
||||
| `IssueObservation` | A GitHub issue (current or archived) with title, body, timestamp |
|
||||
| `PRObservation` | A GitHub PR (current or archived) with diff summary, reviewers |
|
||||
| `IOC` | A single Indicator of Compromise with context |
|
||||
| `TimelineGap` | A period with unusual absence of expected activity |
|
||||
| `ActorAnomalyObservation` | Behavioral anomaly for a specific GitHub actor |
|
||||
| `WorkflowAnomalyObservation` | Suspicious CI/CD workflow change or unexpected run |
|
||||
| `CrossSourceDiscrepancy` | Item present in one source but absent in another (strong deletion indicator) |
|
||||
@@ -0,0 +1,184 @@
|
||||
# GitHub Archive Query Guide (BigQuery)
|
||||
|
||||
GitHub Archive records every public event on GitHub as immutable JSON records. This data is accessible via Google BigQuery and is the most reliable source for forensic investigation — events cannot be deleted or modified after recording.
|
||||
|
||||
## Public Dataset
|
||||
|
||||
- **Project**: `githubarchive`
|
||||
- **Tables**: `day.YYYYMMDD`, `month.YYYYMM`, `year.YYYY`
|
||||
- **Cost**: $6.25 per TiB scanned. Always run dry runs first.
|
||||
- **Access**: Requires a Google Cloud account with BigQuery enabled. Free tier includes 1 TiB/month of queries.
|
||||
|
||||
---
|
||||
|
||||
## The 12 GitHub Event Types
|
||||
|
||||
| Event Type | What It Records | Forensic Value |
|
||||
|------------|-----------------|----------------|
|
||||
| `PushEvent` | Commits pushed to a branch | Force-push detection, commit timeline, author attribution |
|
||||
| `PullRequestEvent` | PR opened, closed, merged, reopened | Deleted PR recovery, review timeline |
|
||||
| `IssuesEvent` | Issue opened, closed, reopened, labeled | Deleted issue recovery, social engineering traces |
|
||||
| `IssueCommentEvent` | Comments on issues and PRs | Deleted comment recovery, communication patterns |
|
||||
| `CreateEvent` | Branch, tag, or repository creation | Suspicious branch creation, tag timing |
|
||||
| `DeleteEvent` | Branch or tag deletion | Evidence of cleanup after compromise |
|
||||
| `MemberEvent` | Collaborator added or removed | Permission changes, access escalation |
|
||||
| `PublicEvent` | Repository made public | Accidental exposure of private repos |
|
||||
| `WatchEvent` | User stars a repository | Actor reconnaissance patterns |
|
||||
| `ForkEvent` | Repository forked | Exfiltration of code before cleanup |
|
||||
| `ReleaseEvent` | Release published, edited, deleted | Malicious release injection, deleted release recovery |
|
||||
| `WorkflowRunEvent` | GitHub Actions workflow triggered | CI/CD abuse, unauthorized workflow runs |
|
||||
|
||||
---
|
||||
|
||||
## Query Templates
|
||||
|
||||
### Basic: All Events for a Repository
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
created_at,
|
||||
type,
|
||||
actor.login,
|
||||
repo.name,
|
||||
payload
|
||||
FROM
|
||||
`githubarchive.day.20240101` -- Adjust date
|
||||
WHERE
|
||||
repo.name = 'owner/repo'
|
||||
AND type IN ('PushEvent', 'DeleteEvent', 'MemberEvent')
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
```
|
||||
|
||||
### Force-Push Detection
|
||||
|
||||
Force-pushes produce PushEvents where commits are overwritten. Key indicators:
|
||||
- `payload.distinct_size = 0` with `payload.size > 0` → commits were erased
|
||||
- `payload.before` contains the SHA before the rewrite (recoverable)
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
created_at,
|
||||
actor.login,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.before') AS before_sha,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.head') AS after_sha,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.size') AS total_commits,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.distinct_size') AS distinct_commits,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.ref') AS branch_ref
|
||||
FROM
|
||||
`githubarchive.month.*`
|
||||
WHERE
|
||||
_TABLE_SUFFIX BETWEEN '202401' AND '202403'
|
||||
AND type = 'PushEvent'
|
||||
AND repo.name = 'owner/repo'
|
||||
AND CAST(JSON_EXTRACT_SCALAR(payload, '$.distinct_size') AS INT64) = 0
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
```
|
||||
|
||||
### Deleted Branch/Tag Detection
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
created_at,
|
||||
actor.login,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.ref') AS deleted_ref,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.ref_type') AS ref_type
|
||||
FROM
|
||||
`githubarchive.month.*`
|
||||
WHERE
|
||||
_TABLE_SUFFIX BETWEEN '202401' AND '202403'
|
||||
AND type = 'DeleteEvent'
|
||||
AND repo.name = 'owner/repo'
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
```
|
||||
|
||||
### Collaborator Permission Changes
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
created_at,
|
||||
actor.login,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.action') AS action,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.member.login') AS member
|
||||
FROM
|
||||
`githubarchive.month.*`
|
||||
WHERE
|
||||
_TABLE_SUFFIX BETWEEN '202401' AND '202403'
|
||||
AND type = 'MemberEvent'
|
||||
AND repo.name = 'owner/repo'
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
```
|
||||
|
||||
### CI/CD Workflow Activity
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
created_at,
|
||||
actor.login,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.action') AS action,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.workflow_run.name') AS workflow_name,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.workflow_run.conclusion') AS conclusion,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.workflow_run.head_sha') AS head_sha
|
||||
FROM
|
||||
`githubarchive.month.*`
|
||||
WHERE
|
||||
_TABLE_SUFFIX BETWEEN '202401' AND '202403'
|
||||
AND type = 'WorkflowRunEvent'
|
||||
AND repo.name = 'owner/repo'
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
```
|
||||
|
||||
### Actor Activity Profiling
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
type,
|
||||
COUNT(*) AS event_count,
|
||||
MIN(created_at) AS first_event,
|
||||
MAX(created_at) AS last_event
|
||||
FROM
|
||||
`githubarchive.month.*`
|
||||
WHERE
|
||||
_TABLE_SUFFIX BETWEEN '202301' AND '202412'
|
||||
AND actor.login = 'suspicious-username'
|
||||
GROUP BY type
|
||||
ORDER BY event_count DESC
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cost Optimization (MANDATORY)
|
||||
|
||||
1. **Always dry run first**: Add `--dry_run` flag to `bq query` to see estimated bytes scanned before executing.
|
||||
2. **Use `_TABLE_SUFFIX`**: Narrow the date range as much as possible. `day.*` tables are cheapest for narrow windows; `month.*` for broader sweeps.
|
||||
3. **Select only needed columns**: Avoid `SELECT *`. The `payload` column is large — only select specific JSON paths.
|
||||
4. **Add LIMIT**: Use `LIMIT 1000` during exploration. Remove only for final exhaustive queries.
|
||||
5. **Column filtering in WHERE**: Filter on indexed columns (`type`, `repo.name`, `actor.login`) before payload extraction.
|
||||
|
||||
**Cost estimation**: A single month of GH Archive data is ~1-2 TiB uncompressed. Querying a specific repo + event type with `_TABLE_SUFFIX` typically scans 1-10 GiB ($0.006-$0.06).
|
||||
|
||||
---
|
||||
|
||||
## Accessing via Hermes
|
||||
|
||||
**Option A: BigQuery CLI** (if `gcloud` is installed)
|
||||
```bash
|
||||
bq query --use_legacy_sql=false --format=json "YOUR QUERY"
|
||||
```
|
||||
|
||||
**Option B: Python** (via `execute_code`)
|
||||
```python
|
||||
from google.cloud import bigquery
|
||||
client = bigquery.Client()
|
||||
query = "YOUR QUERY"
|
||||
results = client.query(query).result()
|
||||
for row in results:
|
||||
print(dict(row))
|
||||
```
|
||||
|
||||
**Option C: No GCP credentials available**
|
||||
If BigQuery is unavailable, document this limitation in the report. Use the other 4 investigators (Git, GitHub API, Wayback Machine, IOC Enrichment) — they cover most investigation needs without BigQuery.
|
||||
@@ -0,0 +1,131 @@
|
||||
# Investigation Templates
|
||||
|
||||
Pre-built hypothesis and investigation templates for common supply chain attack scenarios.
|
||||
Each template includes: attack pattern, key evidence to collect, and hypothesis starters.
|
||||
|
||||
---
|
||||
|
||||
## Template 1: Maintainer Account Compromise
|
||||
|
||||
**Pattern**: Attacker gains access to a legitimate maintainer account (phishing, credential stuffing)
|
||||
and uses it to push malicious code, create backdoored releases, or exfiltrate CI secrets.
|
||||
|
||||
**Real-world examples**: XZ Utils (2024), Codecov (2021), event-stream (2018)
|
||||
|
||||
**Key Evidence to Collect**:
|
||||
- [ ] Push events from maintainer account outside normal working hours/timezone
|
||||
- [ ] Commits adding new dependencies, obfuscated code, or modified build scripts
|
||||
- [ ] Release creation immediately after suspicious push (to maximize package distribution)
|
||||
- [ ] MemberEvent adding unknown collaborators (attacker adding backup access)
|
||||
- [ ] WorkflowRunEvent with unexpected secret access or exfiltration-like behavior
|
||||
- [ ] Account login location changes (check social media, conference talks for corroboration)
|
||||
|
||||
**Hypothesis Starters**:
|
||||
```
|
||||
[HYPOTHESIS] Actor <HANDLE>'s account was compromised on or around <DATE>,
|
||||
based on anomalous commit timing [EV-XXXX] and geographic access patterns [EV-YYYY].
|
||||
```
|
||||
```
|
||||
[HYPOTHESIS] Release <VERSION> was published by the compromised account to push
|
||||
malicious code to downstream users, evidenced by the malicious commit [EV-XXXX]
|
||||
being added <N> hours before the release [EV-YYYY].
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Template 2: Malicious Dependency Injection
|
||||
|
||||
**Pattern**: A trusted package is modified to include malicious code in a dependency,
|
||||
or a new malicious dependency is injected into an existing package.
|
||||
|
||||
**Key Evidence to Collect**:
|
||||
- [ ] Diff of `package.json`/`requirements.txt`/`go.mod` before and after suspicious commit
|
||||
- [ ] The new dependency's publication timestamp vs. the injection commit timestamp
|
||||
- [ ] Whether the new dependency exists on npm/PyPI and who owns it
|
||||
- [ ] Any obfuscation patterns in the injected dependency code
|
||||
- [ ] Install-time scripts (`postinstall`, `setup.py`, etc.) that execute code on install
|
||||
|
||||
**Hypothesis Starters**:
|
||||
```
|
||||
[HYPOTHESIS] Commit <SHA> [EV-XXXX] introduced dependency <PACKAGE@VERSION>
|
||||
which appears to be a malicious package published by actor <HANDLE> [EV-YYYY],
|
||||
designed to execute <BEHAVIOR> during installation.
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Template 3: CI/CD Pipeline Injection
|
||||
|
||||
**Pattern**: Attacker modifies GitHub Actions workflows to steal secrets, exfiltrate code,
|
||||
or inject malicious artifacts into the build output.
|
||||
|
||||
**Key Evidence to Collect**:
|
||||
- [ ] Diff of all `.github/workflows/*.yml` files before/after suspicious period
|
||||
- [ ] WorkflowRunEvents triggered by the modified workflows
|
||||
- [ ] Any `curl`, `wget`, or network calls added to workflow steps
|
||||
- [ ] New or modified `env:` sections referencing `secrets.*`
|
||||
- [ ] Artifacts produced by modified workflow runs
|
||||
|
||||
**Hypothesis Starters**:
|
||||
```
|
||||
[HYPOTHESIS] Workflow file <FILE> was modified in commit <SHA> [EV-XXXX] to
|
||||
exfiltrate repository secrets via <METHOD>, as evidenced by the added network
|
||||
call pattern [EV-YYYY].
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Template 4: Typosquatting / Dependency Confusion
|
||||
|
||||
**Pattern**: Attacker registers a package with a name similar to a popular package
|
||||
(or an internal package name) to intercept installs from users who mistype.
|
||||
|
||||
**Key Evidence to Collect**:
|
||||
- [ ] Registration timestamp of the suspicious package on the registry
|
||||
- [ ] Package content: does it contain malicious code or is it a stub?
|
||||
- [ ] Download statistics for the suspicious package
|
||||
- [ ] Names of internal packages that could be targeted (if private repo scope)
|
||||
- [ ] Any references to the legitimate package in the malicious one's metadata
|
||||
|
||||
**Hypothesis Starters**:
|
||||
```
|
||||
[HYPOTHESIS] Package <MALICIOUS_NAME> was registered on <DATE> [EV-XXXX] to
|
||||
typosquat on <LEGITIMATE_NAME>, targeting users who misspell the package name.
|
||||
The package contains <BEHAVIOR> [EV-YYYY].
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Template 5: Force-Push History Rewrite (Evidence Erasure)
|
||||
|
||||
**Pattern**: After a malicious commit is detected (or before wider notice), the attacker
|
||||
force-pushes to remove the malicious commit from branch history.
|
||||
|
||||
**Detection is key** — this template focuses on proving the erasure happened.
|
||||
|
||||
**Key Evidence to Collect**:
|
||||
- [ ] GH Archive PushEvent with `distinct_size=0` (force push indicator) [EV-XXXX]
|
||||
- [ ] The SHA of the commit BEFORE the force push (from GH Archive `payload.before`)
|
||||
- [ ] Recovery of the erased commit via direct URL or `git fetch origin SHA`
|
||||
- [ ] Wayback Machine snapshot of the commit page before erasure
|
||||
- [ ] Timeline gap in git log (N commits visible in archive but M < N in current repo)
|
||||
|
||||
**Hypothesis Starters**:
|
||||
```
|
||||
[HYPOTHESIS] Actor <HANDLE> force-pushed branch <BRANCH> on <DATE> [EV-XXXX]
|
||||
to erase commit <SHA> [EV-YYYY], which contained <MALICIOUS_CONTENT>.
|
||||
The erased commit was recovered via <METHOD> [EV-ZZZZ].
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cross-Cutting Investigation Checklist
|
||||
|
||||
Apply to every investigation regardless of template:
|
||||
|
||||
- [ ] Check all contributors for newly created accounts (< 30 days old at time of malicious activity)
|
||||
- [ ] Check if any maintainer account changed email in the period (sign of account takeover)
|
||||
- [ ] Verify GPG signatures on suspicious commits match known maintainer keys
|
||||
- [ ] Check if the repository changed ownership or transferred orgs near the incident
|
||||
- [ ] Look for "cleanup" commits immediately after the malicious commit (cover-up pattern)
|
||||
- [ ] Check related packages/repos by the same author for similar patterns
|
||||
@@ -0,0 +1,164 @@
|
||||
# Deleted Content Recovery Techniques
|
||||
|
||||
## Key Insight: GitHub Never Fully Deletes Force-Pushed Commits
|
||||
|
||||
Force-pushed commits are removed from the branch history but REMAIN on GitHub's servers until garbage collection runs (which can take weeks to months). This is the foundation of deleted commit recovery.
|
||||
|
||||
---
|
||||
|
||||
## Method 1: Direct GitHub URL (Fastest — No Auth Required)
|
||||
|
||||
If you have a commit SHA, access it directly even if it was force-pushed off a branch:
|
||||
|
||||
```bash
|
||||
# View commit metadata
|
||||
curl -s "https://github.com/OWNER/REPO/commit/SHA"
|
||||
|
||||
# Download as patch (includes full diff)
|
||||
curl -s "https://github.com/OWNER/REPO/commit/SHA.patch" > recovered_commit.patch
|
||||
|
||||
# Download as diff
|
||||
curl -s "https://github.com/OWNER/REPO/commit/SHA.diff" > recovered_commit.diff
|
||||
|
||||
# Example (Istio credential leak - real incident):
|
||||
curl -s "https://github.com/istio/istio/commit/FORCE_PUSHED_SHA.patch"
|
||||
```
|
||||
|
||||
**When this works**: SHA is known (from GH Archive, Wayback Machine, or `git fsck`)
|
||||
**When this fails**: GitHub has already garbage-collected the object (rare, typically 30–90 days post-force-push)
|
||||
|
||||
---
|
||||
|
||||
## Method 2: GitHub REST API
|
||||
|
||||
```bash
|
||||
# Works for commits force-pushed off branches but still on server
|
||||
# Note: /commits/SHA may 404, but /git/commits/SHA often succeeds for orphaned commits
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/git/commits/SHA" | jq .
|
||||
|
||||
# Get the tree (file listing) of a force-pushed commit
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/git/trees/SHA?recursive=1" | jq .
|
||||
|
||||
# Get a specific file from a force-pushed commit
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/contents/PATH?ref=SHA" | jq .content | base64 -d
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Method 3: Git Fetch by SHA (Local — Requires Clone)
|
||||
|
||||
```bash
|
||||
# Fetch an orphaned commit directly by SHA into local repo
|
||||
cd target_repo
|
||||
git fetch origin SHA
|
||||
git log FETCH_HEAD -1 # view the commit
|
||||
git diff FETCH_HEAD~1 FETCH_HEAD # view the diff
|
||||
|
||||
# If the SHA was recently force-pushed it will still be fetchable
|
||||
# This stops working once GitHub GC runs
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Method 4: Dangling Commits via git fsck
|
||||
|
||||
```bash
|
||||
cd target_repo
|
||||
|
||||
# Find all unreachable objects (includes force-pushed commits)
|
||||
git fsck --unreachable --no-reflogs 2>&1 | grep "unreachable commit" | awk '{print $3}' > dangling_shas.txt
|
||||
|
||||
# For each dangling commit, get its metadata
|
||||
while read sha; do
|
||||
echo "=== $sha ===" >> dangling_details.txt
|
||||
git show --stat "$sha" >> dangling_details.txt 2>&1
|
||||
done < dangling_shas.txt
|
||||
|
||||
# Note: dangling objects only exist in LOCAL clone — not the same as GitHub's copies
|
||||
# GitHub's copies are accessible via Methods 1-3 until GC runs
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Recovering Deleted GitHub Issues and PRs
|
||||
|
||||
### Via Wayback Machine CDX API
|
||||
|
||||
```bash
|
||||
# Find all archived snapshots of a specific issue
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO/issues/NUMBER&output=json&limit=50&fl=timestamp,statuscode,original" | python3 -m json.tool
|
||||
|
||||
# Fetch the best snapshot
|
||||
# Use the timestamp from the CDX result:
|
||||
# https://web.archive.org/web/TIMESTAMP/https://github.com/OWNER/REPO/issues/NUMBER
|
||||
curl -s "https://web.archive.org/web/TIMESTAMP/https://github.com/OWNER/REPO/issues/NUMBER" > issue_NUMBER_archived.html
|
||||
|
||||
# Find all snapshots of the repo in a date range
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO*&output=json&from=20240101&to=20240201&limit=200&fl=timestamp,urlkey,statuscode" | python3 -m json.tool
|
||||
```
|
||||
|
||||
### Via GitHub API (Limited — Only Non-Deleted Content)
|
||||
|
||||
```bash
|
||||
# Closed issues (not deleted) are retrievable
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/issues?state=closed&per_page=100" | jq '.[].number'
|
||||
|
||||
# Note: DELETED issues/PRs do NOT appear in the API. Use Wayback Machine or GH Archive for those.
|
||||
```
|
||||
|
||||
### Via GitHub Archive (For Event History — Not Content)
|
||||
|
||||
```sql
|
||||
-- Find all IssueEvents for a repo in a date range
|
||||
SELECT created_at, actor.login, payload.action, payload.issue.number, payload.issue.title
|
||||
FROM `githubarchive.day.*`
|
||||
WHERE _TABLE_SUFFIX BETWEEN '20240101' AND '20240201'
|
||||
AND type = 'IssuesEvent'
|
||||
AND repo.name = 'OWNER/REPO'
|
||||
ORDER BY created_at
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Recovering Deleted Files from a Known Commit
|
||||
|
||||
```bash
|
||||
# If you have the commit SHA (even force-pushed):
|
||||
git show SHA:path/to/file.py > recovered_file.py
|
||||
|
||||
# Or via API (base64 encoded content):
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/contents/path/to/file.py?ref=SHA" | python3 -c "
|
||||
import sys, json, base64
|
||||
d = json.load(sys.stdin)
|
||||
print(base64.b64decode(d['content']).decode())
|
||||
"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Evidence Recording
|
||||
|
||||
After recovering any deleted content, immediately record it:
|
||||
|
||||
```bash
|
||||
python3 SKILL_DIR/scripts/evidence-store.py --store evidence.json add \
|
||||
--source "git fetch origin FORCE_PUSHED_SHA" \
|
||||
--content "Recovered commit: FORCE_PUSHED_SHA | Author: attacker@example.com | Date: 2024-01-15 | Added file: malicious.sh" \
|
||||
--type git \
|
||||
--actor "attacker-handle" \
|
||||
--url "https://github.com/OWNER/REPO/commit/FORCE_PUSHED_SHA.patch" \
|
||||
--timestamp "2024-01-15T00:00:00Z" \
|
||||
--verification single_source \
|
||||
--notes "Commit force-pushed off main branch on 2024-01-16. Recovered via direct fetch."
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Recovery Failure Modes
|
||||
|
||||
| Failure | Cause | Workaround |
|
||||
|---------|-------|------------|
|
||||
| `git fetch origin SHA` returns "not our ref" | GitHub GC already ran | Try Method 1/2, search Wayback Machine |
|
||||
| `github.com/OWNER/REPO/commit/SHA` returns 404 | GC ran or SHA is wrong | Verify SHA via GH Archive; try partial SHA search |
|
||||
| Wayback Machine has no snapshots | Page was never crawled by IA | Check `commoncrawl.org`, check Google Cache |
|
||||
| BigQuery shows event but no content | GH Archive stores event metadata, not file contents | Recovery only reveals the event occurred, not the content |
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user