
Add support for Llama models #25

Merged · 2 commits, merged Apr 23, 2024

3 changes: 1 addition & 2 deletions composer.json
@@ -21,8 +21,7 @@
"codewithkyrian/onnxruntime-downloader-plugin": "^1.1",
"symfony/console": "^6.4|^7.0",
"imagine/imagine": "^1.3",
"rokka/imagine-vips": "^0.31.0",
"spatie/fork": "^1.2"
"rokka/imagine-vips": "^0.31.0"
},
"require-dev": {
"pestphp/pest": "^2.31",
7 changes: 4 additions & 3 deletions examples/pipelines/text-generation.php
@@ -12,13 +12,14 @@
ini_set('memory_limit', -1);
//
//$generator = pipeline('text-generation', 'Xenova/gpt2');
$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
//
//$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
$generator = pipeline('text-generation', 'Xenova/TinyLlama-1.1B-Chat-v1.0');

$streamer = StdOutStreamer::make();

$messages = [
['role' => 'system', 'content' => 'You are a helpful assistant.'],
['role' => 'user', 'content' => 'What is diffusion?'],
['role' => 'user', 'content' => 'What is diffusion in chemistry?'],
];

$input = $generator->tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);
30 changes: 14 additions & 16 deletions src/Decoders/ByteFallback.php
@@ -18,8 +18,8 @@ public function __construct(array $config)

protected function decodeChain(array $tokens): array
{
$new_tokens = [];
$previous_byte_tokens = [];
$newTokens = [];
$previousByteTokens = [];

foreach ($tokens as $token) {
$bytes = null;
@@ -30,22 +30,22 @@ protected function decodeChain(array $tokens): array
}
}
if ($bytes !== null) {
$previous_byte_tokens[] = $bytes;
$previousByteTokens[] = $bytes;
} else {
if (count($previous_byte_tokens) > 0) {
$string = $this->bytesToString($previous_byte_tokens);
$new_tokens[] = $string;
$previous_byte_tokens = [];
if (count($previousByteTokens) > 0) {
$string = $this->bytesToString($previousByteTokens);
$newTokens[] = $string;
$previousByteTokens = [];
}
$new_tokens[] = $token;
$newTokens[] = $token;
}
}
if (count($previous_byte_tokens) > 0) {
$string = $this->bytesToString($previous_byte_tokens);
$new_tokens[] = $string;
if (count($previousByteTokens) > 0) {
$string = $this->bytesToString($previousByteTokens);
$newTokens[] = $string;
}

return $new_tokens;
return $newTokens;
}

/**
@@ -56,9 +56,7 @@ protected function decodeChain(array $tokens): array
*/
protected function bytesToString(array $bytes): string
{
$chars = array_map(function ($byte) {
return chr($byte);
}, $bytes);
return implode('', $chars);
$binaryString = pack('C*', ...$bytes);
return mb_convert_encoding($binaryString, 'ISO-8859-1');
}
}
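
For context, byte-fallback tokens look like "<0x41>"; decodeChain() buffers consecutive ones and bytesToString() collapses the buffered byte values into a single string. A minimal sketch of that conversion, with illustrative byte values rather than real tokenizer output:

// The byte-fallback tokens <0xE2>, <0x96>, <0x81> are the UTF-8 bytes of "▁".
$bytes = [0xE2, 0x96, 0x81];
$binaryString = pack('C*', ...$bytes); // raw byte string "\xE2\x96\x81"
echo $binaryString;                    // displays "▁" in a UTF-8 terminal
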
2 changes: 1 addition & 1 deletion src/Decoders/DecoderSequence.php
@@ -29,7 +29,7 @@ protected function decodeChain(array $tokens): array
{
return array_reduce(
$this->decoders,
fn(array $tokens, Decoder $decoder) => $decoder->decode($tokens),
fn(array $tokens, Decoder $decoder) => $decoder->decodeChain($tokens),
$tokens
);
}
2 changes: 1 addition & 1 deletion src/Decoders/FuseDecoder.php
@@ -18,6 +18,6 @@ public function __construct(array $config)

protected function decodeChain(array $tokens): array
{
return [implode('', $tokens)];
return [implode('', $tokens)];
}
}
9 changes: 8 additions & 1 deletion src/Decoders/ReplaceDecoder.php
@@ -20,10 +20,17 @@ protected function decodeChain(array $tokens): array
{
$pattern = $this->config['pattern'] ?? null;


return $pattern == null ?
$tokens :
array_map(function ($token) use ($pattern) {
return str_replace($pattern, $this->config['content'], $token);
if (isset($pattern['Regex'])) {
return preg_replace("/{$pattern['Regex']}/u", $this->config['content'], (string)$token);
} elseif (isset($pattern['String'])) {
return str_replace($pattern['String'], $this->config['content'], (string)$token);
} else {
return $token;
}
}, $tokens);
}
}
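
The two new branches mirror how the pattern is keyed in tokenizer.json: either a literal string or a regular expression. A hedged sketch of a config this code would handle (values are illustrative, not taken from this PR):

// Hypothetical "Replace" decoder entry as a Llama-style tokenizer.json might define it:
$config = [
    'pattern' => ['String' => '▁'],   // or ['Regex' => '...'] for the regex branch
    'content' => ' ',
];
// With this config, decodeChain() maps each token through
// str_replace('▁', ' ', $token), e.g. '▁Hello' becomes ' Hello'.
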
4 changes: 2 additions & 2 deletions src/Decoders/StripDecoder.php
@@ -28,7 +28,7 @@ protected function decodeChain(array $tokens): array
return array_map(function ($token) {
$startCut = 0;
for ($i = 0; $i < $this->start; ++$i) {
if ($token[$i] === $this->content) {
if (($token[$i] ?? null) === $this->content) {
$startCut = $i + 1;
continue;
} else {
@@ -39,7 +39,7 @@ protected function decodeChain(array $tokens): array
$stopCut = strlen($token);
for ($i = 0; $i < $this->stop; ++$i) {
$index = strlen($token) - $i - 1;
if ($token[$index] === $this->content) {
if (($token[$index] ?? null) === $this->content) {
$stopCut = $index;
continue;
} else {
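
The null-coalescing guards only protect the character lookups on tokens shorter than the configured strip counts; the stripping behaviour itself is unchanged. A small illustration under the assumption of a Llama-style Strip config (a single leading space):

// Hypothetical Strip decoder config; values are for illustration only.
$config = ['content' => ' ', 'start' => 1, 'stop' => 0];
// decodeChain([' Hello world', '']) would return ['Hello world', '']:
// at most one leading ' ' is removed per token, and the guard keeps the
// empty token from tripping an undefined string offset warning.
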
1 change: 1 addition & 0 deletions src/Models/Auto/AutoModel.php
@@ -38,6 +38,7 @@ class AutoModel extends PretrainedMixin
"gptj" => \Codewithkyrian\Transformers\Models\Pretrained\GPTJModel::class,
"gpt_bigcode" => \Codewithkyrian\Transformers\Models\Pretrained\GPTBigCodeModel::class,
"codegen" => \Codewithkyrian\Transformers\Models\Pretrained\CodeGenModel::class,
"llama" => \Codewithkyrian\Transformers\Models\Pretrained\LlamaModel::class,
"qwen2" => \Codewithkyrian\Transformers\Models\Pretrained\Qwen2Model::class,
];

1 change: 1 addition & 0 deletions src/Models/Auto/AutoModelForCausalLM.php
@@ -12,6 +12,7 @@ class AutoModelForCausalLM extends PretrainedMixin
'gptj' => \Codewithkyrian\Transformers\Models\Pretrained\GPTJForCausalLM::class,
'gpt_bigcode' => \Codewithkyrian\Transformers\Models\Pretrained\GPTBigCodeForCausalLM::class,
'codegen' => \Codewithkyrian\Transformers\Models\Pretrained\CodeGenForCausalLM::class,
'llama' => \Codewithkyrian\Transformers\Models\Pretrained\LlamaForCausalLM::class,
'trocr' => \Codewithkyrian\Transformers\Models\Pretrained\TrOCRForCausalLM::class,
'qwen2' => \Codewithkyrian\Transformers\Models\Pretrained\Qwen2ForCausalLM::class
];
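
With the "llama" mapping registered, causal-LM checkpoints whose config has model_type "llama" resolve to the new classes. A rough usage sketch, assuming the static fromPretrained() entry point the auto classes expose and reusing the model id from the updated example script:

use Codewithkyrian\Transformers\Models\Auto\AutoModelForCausalLM;

// Resolves via the "llama" => LlamaForCausalLM::class entry added above.
$model = AutoModelForCausalLM::fromPretrained('Xenova/TinyLlama-1.1B-Chat-v1.0');
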
2 changes: 2 additions & 0 deletions src/Models/ModelArchitecture.php
@@ -119,6 +119,8 @@ protected function decoderRunBeam(PretrainedModel $model, array &$beam): array
'past_key_values' => $beam['prev_model_outputs']['past_key_values'] ?? null,
];


// 2. Run
$output = $model->forward($modelInputs);

// 3. Update
11 changes: 11 additions & 0 deletions src/Models/Pretrained/LlamaForCausalLM.php
@@ -0,0 +1,11 @@
<?php

declare(strict_types=1);


namespace Codewithkyrian\Transformers\Models\Pretrained;

class LlamaForCausalLM extends LlamaPretrainedModel
{

}
14 changes: 14 additions & 0 deletions src/Models/Pretrained/LlamaModel.php
@@ -0,0 +1,14 @@
<?php

declare(strict_types=1);


namespace Codewithkyrian\Transformers\Models\Pretrained;

/**
* The bare LLaMA Model outputting raw hidden-states without any specific head on top.
*/
class LlamaModel extends LlamaPretrainedModel
{

}
40 changes: 40 additions & 0 deletions src/Models/Pretrained/LlamaPretrainedModel.php
@@ -0,0 +1,40 @@
<?php

declare(strict_types=1);


namespace Codewithkyrian\Transformers\Models\Pretrained;

use Codewithkyrian\Transformers\Models\ModelArchitecture;
use Codewithkyrian\Transformers\Utils\AutoConfig;
use Codewithkyrian\Transformers\Utils\GenerationConfig;
use Codewithkyrian\Transformers\Utils\InferenceSession;


/**
* The bare LLaMA Model outputting raw hidden-states without any specific head on top.
*/
class LlamaPretrainedModel extends PretrainedModel
{
protected int $numHeads;
protected int $numLayers;
protected int $dimKv;

public function __construct(
AutoConfig $config,
InferenceSession $session,
public ModelArchitecture $modelArchitecture,
public GenerationConfig $generationConfig
)
{
parent::__construct($config, $session, $modelArchitecture);

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
$this->config['pad_token_id'] = $this->config['eos_token_id'];
$this->config->padTokenId = $this->config['eos_token_id'];

$this->numHeads = $this->config['num_key_value_heads'] ?? $this->config['num_attention_heads'];
$this->numLayers = $this->config['num_hidden_layers'];
$this->dimKv = $this->config['hidden_size'] / $this->config['num_attention_heads'];
}
}
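
The three fields follow the usual decoder-cache convention: numHeads is the key/value head count (num_key_value_heads when the checkpoint uses grouped-query attention), dimKv the per-head dimension, and numLayers the number of cached layers. A hedged sketch of the cache shape they describe (the helper name is illustrative, not the library's API):

// Hypothetical helper: an empty past-key-values tensor for one layer has
// shape [batchSize, numHeads, 0, dimKv]; there are numLayers key/value pairs.
function emptyPastKeyValueShape(int $batchSize, int $numHeads, int $dimKv): array
{
    return [$batchSize, $numHeads, 0, $dimKv]; // no cached tokens yet
}
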
1 change: 0 additions & 1 deletion src/PreTokenizers/MetaspacePreTokenizer.php
@@ -51,7 +51,6 @@ public function preTokenizeText(string|array $text, array $options): array
{
$normalized = str_replace(' ', $this->strRep, $text);


$sectionIndex = $options['section_index'] ?? null;

if (
6 changes: 4 additions & 2 deletions src/PretrainedTokenizers/LlamaTokenizer.php
@@ -9,6 +9,8 @@

class LlamaTokenizer extends PretrainedTokenizer
{
const SPIECE_UNDERLINE = "▁";

protected string $defaultChatTemplate = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\n' + system_message + '\n<</SYS>>\n\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\n' + content.strip() + '\n<</SYS>>\n\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}";

public const DEFAULT_SYSTEM_PROMPT =
@@ -31,7 +33,7 @@ public function __construct(array $tokenizerJSON, array $tokenizerConfig)
// See https://github.com/huggingface/transformers/pull/24565 for more information
$this->normalizer = null;
$this->preTokenizer = new MetaspacePreTokenizer([
'replacement' => '▁',
'replacement' => self::SPIECE_UNDERLINE,
'add_prefix_space' => true,
'prepend_scheme' => 'first',
]);
@@ -58,7 +60,7 @@ public function encodeText(?string $text, string $textPair = null, bool $addSpec
return parent::encodeText($text, $textPair, $addSpecialTokens);
}

$tokens = parent::encodeText('_' . str_replace('_', ' ', $text));
$tokens = parent::encodeText(self::SPIECE_UNDERLINE . str_replace(self::SPIECE_UNDERLINE, ' ', $text));

if (count($tokens) > 1 && $tokens[0] === '_' && in_array($tokens[1], $this->specialTokens)) {
$tokens = array_slice($tokens, 1);
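
To make the prefix handling concrete, here is a sketch of the string that encodeText() now hands to the parent implementation (illustrative input, not real tokenizer output):

// SPIECE_UNDERLINE prefixing: any literal '▁' in the input is normalised
// to a space, then a single '▁' is prepended to mark the word boundary.
$text = 'Hello▁world';
$prepared = LlamaTokenizer::SPIECE_UNDERLINE
    . str_replace(LlamaTokenizer::SPIECE_UNDERLINE, ' ', $text);
// $prepared === '▁Hello world'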