For ANS and many other statistical coders (e.g. arithmetic coding) you need to create scaled frequencies (the Fs in ANS terminology) from the true counts.
But how do you do that? I've seen many heuristics over the years that are more or less good, but I've never actually seen the right answer. How do you scale to minimize total code len? Well let's do it.
Let's state the problem :
You are given some true counts Cs, with
Sum{Cs} = T , the total of the true counts.
The true probabilities are then
Ps = Cs/T
and the ideal code lens are log2(1/Ps).
You need to create scaled frequencies Fs such that
Sum{Fs} = M
for some given M, and our goal is to minimize the total code len under the counts Fs.
The ideal entropy of the given counts is :
H = Sum{ Ps * log2(1/Ps) }
The code len under the counts Fs is :
L = Sum{ Ps * log2(M/Fs) }
The code len can never be better than the entropy (it equals H only when every Fs/M equals Ps) :
L >= H
We must also meet the constraint
if ( Cs != 0 ) then Fs > 0
That is, all symbols that exist in the set must be codeable. (Note that this is not actually optimal; it's usually better to replace all rare symbols with a single escape symbol, but we won't do that here.)
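To make the two quantities concrete, here is a minimal C++ sketch that computes H and L directly from the counts. It is not code from the post: the names entropy_bits and codelen_bits are hypothetical, and M is taken to be Sum{Fs}.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// H = Sum{ Ps * log2(1/Ps) } in bits per symbol, with Ps = Cs/T.
double entropy_bits(const std::vector<uint32_t>& C) {
    double T = 0;
    for (uint32_t c : C) T += c;
    double H = 0;
    for (uint32_t c : C)
        if (c != 0) H += (c / T) * std::log2(T / c);
    return H;
}

// L = Sum{ Ps * log2(M/Fs) }, taking M = Sum{Fs}.
// Asserts the constraint: Fs > 0 wherever Cs != 0.
double codelen_bits(const std::vector<uint32_t>& C, const std::vector<uint32_t>& F) {
    assert(C.size() == F.size());
    double T = 0, M = 0;
    for (std::size_t s = 0; s < C.size(); s++) { T += C[s]; M += F[s]; }
    double L = 0;
    for (std::size_t s = 0; s < C.size(); s++) {
        if (C[s] == 0) continue;
        assert(F[s] > 0);
        L += (C[s] / T) * std::log2(M / F[s]);
    }
    return L;
}
```

For counts {3,1}, Fs = {6,2} (exactly proportional) gives L equal to H (about 0.811 bits), while Fs = {2,2} gives L = 1 bit.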
The naive solution is :
Fs = round( M * Ps )
if ( Cs > 0 ) Fs = MAX(Fs,1);
which is just scaling up the Ps by M. This has two problems : one is that Sum{Fs} is not actually M. The other is that just rounding the float does not actually distribute the integer counts to minimize codelen.
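A minimal sketch of the naive rule makes the first problem easy to see. It assumes counts are held in a std::vector<uint32_t>, and naive_scale is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Naive scaling: Fs = round(M * Ps), then Fs = MAX(Fs,1) for every present symbol.
// The result generally does NOT sum to M.
std::vector<uint32_t> naive_scale(const std::vector<uint32_t>& C, uint32_t M) {
    double T = 0;
    for (uint32_t c : C) T += c;
    assert(T > 0);
    std::vector<uint32_t> F(C.size(), 0);
    for (std::size_t s = 0; s < C.size(); s++) {
        if (C[s] == 0) continue;                      // absent symbols stay at 0
        F[s] = (uint32_t)std::lround(M * (C[s] / T)); // round( M * Ps )
        F[s] = std::max<uint32_t>(F[s], 1);           // keep it codeable
    }
    return F;
}
```

With counts {1,1,1} and M = 4 it returns {1,1,1}, which sums to 3; with {1000,1,1} and M = 16 the MAX(Fs,1) clamp pushes the sum up to 18.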
The usual heuristic is to do something like the above, and then apply some fix to make the sum right.
So first let's address how to fix the sum. The sum will almost always be off from M because of integer rounding (and the MAX(Fs,1) clamp).
What you will have is some correction :
correction = M - Sum{Fs}
which can be positive or negative. This is a count that needs to be added onto some symbols. We want to add it to the symbols that give us the most benefit to L, the total code len. Well that's simple, we just measure the effect of changing each Fs :
correction_sign = correction > 0 ? 1 : -1;
Ls_before = Ps * log2(M/Fs)
Ls_after = Ps * log2(M/(Fs + correction_sign))
Ls_delta = Ls_after - Ls_before
Ls_delta = Ps * ( log2(M/(Fs + correction_sign)) - log2(M/Fs) )
Ls_delta = Ps * log2(Fs/(Fs + correction_sign))
so we just need to find the symbol that gives us the lowest Ls_delta. This is either an improvement to total L, or the least increase in L. (When correction is negative, symbols with Fs = 1 are skipped, since taking them to 0 would violate the Fs > 0 constraint.)
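A direct O(alphabet) version of this single step might look like the sketch below. ls_delta and best_symbol are hypothetical names; it uses Cs in place of Ps, since multiplying every delta by T does not change which symbol wins.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Ls_delta scaled by T: Cs * log2( Fs / (Fs + correction_sign) ).
double ls_delta(uint32_t Cs, uint32_t Fs, int correction_sign) {
    return Cs * std::log2((double)Fs / ((double)Fs + correction_sign));
}

// Returns the present symbol with the lowest Ls_delta for one +1/-1 step,
// or -1 if no symbol can take the step (never moves a present symbol to Fs = 0).
int best_symbol(const std::vector<uint32_t>& C, const std::vector<uint32_t>& F,
                int correction_sign) {
    assert(correction_sign == 1 || correction_sign == -1);
    int best = -1;
    double best_delta = 0;
    for (std::size_t s = 0; s < C.size(); s++) {
        if (C[s] == 0) continue;
        if (correction_sign < 0 && F[s] <= 1) continue;
        double d = ls_delta(C[s], F[s], correction_sign);
        if (best < 0 || d < best_delta) { best = (int)s; best_delta = d; }
    }
    return best;
}
```

Calling this once per unit of correction is exactly the O(alphabet*correction) approach that the heap replaces.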
We need to apply multiple corrections. We don't want a solution that's O(alphabet*correction), since that can be 256*256 in bad cases. (correction is <= alphabet, and typically in the 150 range for a typical 256-symbol file.) The obvious solution is a heap. In pseudocode :
For all s
  push_heap( Ls_delta , s )
For correction
  s = pop_heap
  adjust Fs
  compute new Ls_delta for s
  push_heap( Ls_delta , s )
Note that after we adjust the count we need to recompute Ls_delta and repush that symbol, because we might want to choose the same symbol again later.
In STL+cblib this is :
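// to[] = Fs (adjusted in place) , from[] = original counts (Cs) , correction = M - Sum{to}
struct sort_sym
{
    int sym;
    float rank;
    sort_sym() { }
    sort_sym( int s, float r ) : sym(s) , rank(r) { }
    bool operator < (const sort_sym & rhs) const { return rank < rhs.rank; }
};

if ( correction != 0 )
{
    int32 correction_sign = (correction > 0) ? 1 : -1;
    vector<sort_sym> heap; // std heaps are max-heaps under operator <, so rank = -Ls_delta*T
    for (int i = 0; i < alphabet; i++)
        if ( from[i] != 0 && ( correction_sign > 0 || to[i] > 1 ) ) // never take a present symbol to 0
            heap.push_back( sort_sym( i, (float)( from[i] * log2( ((double)to[i] + correction_sign) / to[i] ) ) ) );
    std::make_heap( heap.begin(), heap.end() );
    while ( correction != 0 )
    {
        std::pop_heap( heap.begin(), heap.end() ); // best symbol moves to back()
        int i = heap.back().sym;
        heap.pop_back();
        to[i] += correction_sign;
        correction -= correction_sign;
        if ( correction_sign > 0 || to[i] > 1 ) // recompute Ls_delta for i and repush
        {
            heap.push_back( sort_sym( i, (float)( from[i] * log2( ((double)to[i] + correction_sign) / to[i] ) ) ) );
            std::push_heap( heap.begin(), heap.end() );
        }
    }
}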
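For experimenting outside cblib, here is a self-contained C++ sketch of the same heap correction using only the standard library. It is not the post's code: correct_sum is a hypothetical name, and it ranks symbols by gain = -Ls_delta * T so that the standard max-heap pops the best symbol first.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Adjusts F (scaled counts for the true counts C) in place until Sum{F} == M,
// one +1/-1 step at a time, always picking the symbol with the lowest Ls_delta.
// Assumes F[s] > 0 wherever C[s] > 0 and that M >= the number of present symbols.
void correct_sum(const std::vector<uint32_t>& C, std::vector<uint32_t>& F, uint32_t M) {
    int64_t correction = M;
    for (uint32_t f : F) correction -= f;
    if (correction == 0) return;
    const int sign = (correction > 0) ? 1 : -1;

    // Heap of (gain, symbol); pop_heap yields the largest gain = lowest Ls_delta.
    std::vector<std::pair<double, int>> heap;
    auto push = [&](int s) {
        if (sign < 0 && F[s] <= 1) return;  // never take a present symbol to 0
        double gain = C[s] * std::log2(((double)F[s] + sign) / F[s]);
        heap.emplace_back(gain, s);
        std::push_heap(heap.begin(), heap.end());
    };
    for (std::size_t s = 0; s < C.size(); s++)
        if (C[s] != 0) push((int)s);

    while (correction != 0) {
        assert(!heap.empty());
        std::pop_heap(heap.begin(), heap.end());
        int s = heap.back().second;
        heap.pop_back();
        F[s] = (uint32_t)((int64_t)F[s] + sign);  // adjust Fs
        correction -= sign;
        push(s);  // recompute Ls_delta for s and repush it
    }
}
```

For true counts {3,1} starting from Fs = {1,1} with M = 8, it adds six counts and lands on {6,2}, the exactly proportional split. Starting from the naive {16,1,1} for counts {1000,1,1} with M = 16, it can only remove from the first symbol and returns {14,1,1}.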

Errkay. So our first attempt is to just use the naive scaling Fs = round( M * Ps ) and then fix the sum using the heap correction algorithm above.
Doing round+correct gets you 99% of the way there. I measured the difference between the total code len made that way and the optimal, and it was less than 0.001 bpb on every file I tried. But it's still not quite right, so what is the right way?
To guide my search I had a look at the cases where round+correct was not optimal. When it's not optimal it means there is some symbol a and some symbol b such that { Fa+1 , Fb-1 } gives a better total code len than {Fa,Fb}. An example of that is :
count to inc : (1/1024) was (1866/1286152 = 0.0015)
count to dec : (380/1024) was (482110/1286152 = 0.3748)
to inc; cl before : 10.00 cl after : 9.00 , true cl : 9.43
to dec; cl before : 1.43 cl after : 1.43 , true cl : 1.42
The key point is on the 1 count :
count to inc : (1/1024) was (1866/1286152 = 0.0015)
to inc; cl before : 10.00 cl after : 9.00 , true cl : 9.43
1024*1866/1286152 = 1.485660
round(1.485660) = 1
so Fs = 1 , which is a codelen of 10
but Fs = 2 gives a codelen (9) closer to the true codelen (9.43)
And this provided the key observation : rather than rounding the scaled count, what we should be doing is taking either the floor or the ceil of the fraction, whichever gives a codelen closer to the true codelen.
BTW before you go off hacking a special case just for Fs==1, it also happens with higher counts :
count to inc : (2/1024) was (439/180084) scaled = 2.4963
to inc; cl before : 9.00 cl after : 8.42 , true cl : 8.68
count to inc : (4/1024) was (644/146557) scaled = 4.4997
to inc; cl before : 8.00 cl after : 7.68 , true cl : 7.83
though obviously the higher Fs is, the less likely this is, because the rounding gets closer to being perfect.
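The examples above can be reproduced with a small check that compares plain rounding against picking the floor or ceil of M*Ps by codelen distance. This is a sketch: codelen, round_scaled and closest_codelen_scaled are hypothetical names, and it uses the direct log2 comparison rather than any simplified form.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Codelen in bits of a symbol with frequency F out of a total of M.
double codelen(double F, double M) { return std::log2(M / F); }

// Scaled count chosen by plain rounding (the naive rule).
uint32_t round_scaled(uint32_t Cs, uint32_t T, uint32_t M) {
    return std::max<uint32_t>((uint32_t)std::lround((double)Cs * M / T), 1);
}

// Floor or ceil of M*Ps, whichever codelen is closer to the true codelen log2(T/Cs).
uint32_t closest_codelen_scaled(uint32_t Cs, uint32_t T, uint32_t M) {
    assert(Cs > 0 && Cs <= T);
    double scaled = (double)Cs * M / T;
    uint32_t down = std::max<uint32_t>((uint32_t)std::floor(scaled), 1);
    double true_cl = codelen(Cs, T);
    double d_down = std::fabs(codelen(down, M) - true_cl);
    double d_up = std::fabs(codelen(down + 1, M) - true_cl);
    return (d_down <= d_up) ? down : down + 1;
}
```

For the three examples above (1866/1286152, 439/180084 and 644/146557 at M = 1024), rounding gives 1, 2 and 4, while the codelen-closest choice is 2, 3 and 5.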
So it's easy enough just to solve exactly, simply pick the floor or ceil of the ratio depending on which makes the closer codelen :
Ps = Cs/T from the true counts
down = floor( M * Ps )
down = MAX( down,1)
Fs = either down or (down+1)
true_codelen = -log2( Ps )
down_codelen = -log2( down/M )
up_codelen = -log2( (down+1)/M )
if ( |down_codelen - true_codelen| < |up_codelen - true_codelen| )
  Fs = down
else
  Fs = down+1
And since all we care about is the inequality, we can do some maths and simplify the expressions. I won't write out all the algebra to do the simplification because it's straightforward, but there are a few key steps :
|log(x)| = log( MAX(x,1/x) )
log(x) >= log(y) is the same as x >= y
down <= M*Ps
down+1 >= M*Ps
the result of the simplification in code is :
// from[] = original counts (Cs) , sum to T
// to[] = normalized counts (Fs) , will sum to M after the correction step
double from_scaled = (double) from[i] * M / T; // cast first, or from[i]*M/T is integer division
uint32 down = (uint32)( from_scaled );
to[i] = ( from_scaled*from_scaled <= (double)down*(down+1) ) ? down : down+1;
Note that there's no special casing needed to ensure that (from_scaled < 1) gives you to[i] = 1; we get that for free with this expression (and a zero count still maps to 0).
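Wrapped up as a whole-array pass, the final expression might look like this C++ sketch (initial_scaled is a hypothetical name). It computes from_scaled in double to avoid integer division, and forms down*(down+1) in double so large M cannot overflow it.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Initial scaled counts via the floor/ceil rule:
//   to[i] = ( from_scaled^2 <= down*(down+1) ) ? down : down+1
// Zero counts map to 0; any nonzero count maps to at least 1.
// The result still needs the heap correction to sum exactly to M.
std::vector<uint32_t> initial_scaled(const std::vector<uint32_t>& from, uint32_t M) {
    uint64_t T = 0;
    for (uint32_t c : from) T += c;
    assert(T > 0);
    std::vector<uint32_t> to(from.size());
    for (std::size_t i = 0; i < from.size(); i++) {
        double from_scaled = (double)from[i] * M / (double)T;
        uint32_t down = (uint32_t)from_scaled;  // floor, since from_scaled >= 0
        to[i] = (from_scaled * from_scaled <= (double)down * (down + 1)) ? down : down + 1;
    }
    return to;
}
```

For counts {0, 1, 999} and M = 16 it returns {0, 1, 16}: the zero count stays 0, the tiny count becomes 1 with no special case, and the sum of 17 is then left for the heap correction to fix.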
I was delighted when I got to this extremely simple final form.
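For reference, the algebra behind that form is short. Below, x stands for from_scaled (M*Ps) and d for down; these symbols and the codelen notation are introduced here only for the derivation. Because d <= x <= d+1, both codelen differences are non-negative logs, so the absolute values drop out. When d = 0 the log form is undefined, but the final product form still gives 1 for any x > 0 and 0 for x = 0.

```latex
\begin{aligned}
&x = M P_s, \quad d = \lfloor x \rfloor, \quad d \le x < d+1 \\
&\ell^{*} = \log_2\frac{M}{x}, \quad \ell_{d} = \log_2\frac{M}{d}, \quad \ell_{d+1} = \log_2\frac{M}{d+1} \\
&|\ell_{d} - \ell^{*}| = \log_2\frac{x}{d}, \quad |\ell_{d+1} - \ell^{*}| = \log_2\frac{d+1}{x} \\
&\text{choose } d \iff \frac{x}{d} \le \frac{d+1}{x} \iff x^{2} \le d\,(d+1)
\end{aligned}
```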
And that is the conclusion. Use that to find the initial scaled counts. There will still be some correction that needs to be applied to reach the target sum exactly, so use the heap correction algorithm above.
As a final note, if we look at the final expression :
to[i] = ( from_scaled*from_scaled < down*(down+1) ) ? down : down+1;
to[i] = ( test < 0 ) ? down : down+1;
test = from_scaled*from_scaled - down*(down+1);
from_scaled = down + frac
test = (down + frac)^2 - down*(down+1);
solve for frac where test = 0
frac = sqrt( down^2 + down ) - down
That gives you the fractional part of the scaled count where you should round up or down. It varies with floor(from_scaled). The actual values are :
down : frac
1 : 0.414214
2 : 0.449490
3 : 0.464102
4 : 0.472136
5 : 0.477226
6 : 0.480741
7 : 0.483315
8 : 0.485281
9 : 0.486833
10 : 0.488088
11 : 0.489125
12 : 0.489996
13 : 0.490738
14 : 0.491377
15 : 0.491933
16 : 0.492423
17 : 0.492856
18 : 0.493242
19 : 0.493589
You can see that as Fs gets larger, frac goes to 0.5, so just using rounding is close to correct. It's really at the very low values, where frac is quite far from 0.5, that errors are most likely to occur.
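The table can be regenerated from the closed form. A minimal sketch, where round_up_threshold and print_thresholds are hypothetical names:

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

// Fractional part f of from_scaled above which the floor/ceil rule picks down+1:
// solving (down + f)^2 = down*(down+1) gives f = sqrt(down^2 + down) - down.
double round_up_threshold(unsigned down) {
    const double d = down;
    return std::sqrt(d * d + d) - d;
}

// Prints "down : threshold" for down = 1..19, in the same format as the table.
void print_thresholds() {
    for (unsigned d = 1; d <= 19; d++)
        std::printf("%u : %f\n", d, round_up_threshold(d));
}
```

print_thresholds() prints the rows above, and round_up_threshold(d) creeps toward 0.5 as d grows (about 0.49988 at d = 1000).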